Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-30 03:34:56 +08:00)

Compare commits: csl/xml_st... → ...ciflow/pul
44 Commits
| SHA1 |
|---|
| b552a4eba1 |
| b3e120665b |
| 0d4992c170 |
| b060e5c131 |
| 6d5e651a50 |
| 3cc5949dc2 |
| f167fd09fa |
| 68b3984b77 |
| a1eb6b5538 |
| f36f372acc |
| d9483d4c8d |
| fea819ed08 |
| 84a2715d34 |
| 572cc12b42 |
| 1fdef664a5 |
| 08ae55021e |
| 96a8d1c5e0 |
| 39307c3db2 |
| 3d6061d56a |
| dc55769bb6 |
| 551921d484 |
| 2c74beddf6 |
| 12ff17857e |
| b5189e269e |
| 3895ce093f |
| 8aa087a29d |
| 7379972cc0 |
| 4ae3c59ce2 |
| 7a8ad5f874 |
| dd09fa089d |
| 0995593caa |
| 69a4358a01 |
| 0ab9e050ab |
| 651e9dbf94 |
| 56bd4c695a |
| 1cb7be9419 |
| 00f68803d3 |
| de1f732075 |
| cbfee32779 |
| 0e38867920 |
| cefd269c35 |
| 7fcf3a1488 |
| 9a88bd06e1 |
| ccc9750df1 |
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*mx8mx8bf16_grouped.*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

@@ -291,6 +291,7 @@ IF(USE_FBGEMM_GENAI)

set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/f4f4bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)

@@ -208,6 +208,48 @@ _f8_f8_bf16_rowwise_grouped_mm(
#endif
}

Tensor&
_f4_f4_bf16_grouped_mm_fbgemm(
const Tensor& mat_a,
const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& global_scale_a,
const Tensor& scale_b,
const Tensor& global_scale_b,
const std::optional<Tensor>& offs,
const std::optional<Tensor>& bias,
Tensor& out) {
#if !defined(USE_ROCM) && defined(USE_FBGEMM_GENAI)
// Typing checks
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());

auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs.value(),
out,
global_scale_a.mul(global_scale_b)
);
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
#endif

return out;
}

void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
@@ -245,7 +287,15 @@ void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int
}
}

void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
void _check_scales_blocked(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) {
// if {mx,nv}fp4, will need to modify K later
bool is_fp4 = (mat.scalar_type() == kFloat4_e2m1fn_x2);
int blocksize = 32;
// check for nvfp4 vs. mxfp4 to fix blocksize
if (is_fp4 && scale.scalar_type() == kFloat8_e4m3fn) {
blocksize = 16;
}

// Checks scales for 2d or 3d target tensors (`mat`).
if (mat.dim() == 2) {
// For MXFP8, 2d tensors have variable size groups represented as subtensors,

@@ -253,17 +303,19 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
// so we can't check the scale sizes without doing a d2h sync to get the group sizes here.
TORCH_CHECK(
scale.dim() == mat.dim(),
"for mxfp8, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(), " and scale.dim() = ", scale.dim(), " for arg ", arg_idx);
"for block-scaled, scale must have same number of dimensions as parent tensor, but got mat.dim() = ", mat.dim(),
" and scale.dim() = ", scale.dim(), " for arg ", arg_idx
);

// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/32, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/32, 4))
// LHS mat shape (M, total_K) -> scale shape (rounded_up(M, 128), rounded_up_per_group(K/blocksize, 4))
// RHS mat shape (total_K, N) -> scale shape (rounded_up(N, 128), rounded_up_per_group(K/blocksize, 4))
// * weight is transposed prior to the call, scale stays non-transposed.
bool LHS = arg_idx == 0;
int scale_dim_to_check = 0;
int mat_dim_to_check = LHS ? 0 : 1;
TORCH_CHECK(
scale.size(scale_dim_to_check) >= mat.size(mat_dim_to_check),
"for mxfp8, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"for block-scaled, arg ", arg_idx, " tensor shape (", mat.size(0), ", ", mat.size(1), ") ",
"must have scale.shape[", scale_dim_to_check, "] >= ", mat.size(mat_dim_to_check), " but got scale.shape=(", scale.size(0), ", ", scale.size(1), ")");
} else {
// For MXFP8, 3d tensors have static group sizes (stack of 2d tensors),
@@ -273,32 +325,40 @@ void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim,
};

// TODO: this is for 3d tensor in 2d-3d case specifically.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8 grouped gemm supports them.
// We'll need to support 3d-3d and 3d-2d cases once mxfp8/nvfp4 grouped gemm supports them.
int64_t G = mat.size(0);
int64_t K = mat.size(1);
if (is_fp4) {
// FP4 packs 2 values into a single 8b word - the "real" K is 2x the
// reported K. Reverse that adjustment.
const int fp4_elems_per_byte = 2;
K *= fp4_elems_per_byte;
}
int64_t N = mat.size(2);
int64_t blocked_scale_K = round_up(K/32, 4);
int64_t blocked_scale_K = round_up(K/blocksize, 4);
int64_t blocked_scale_N = round_up(N, 128);

// fbgemm expects stack of flattened blocked scales for 3d tensor, shape (G, blocked_scale_K * blocked_scale_N).
TORCH_CHECK(
scale.dim() == mat.dim() - 1,
"for mxfp8 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N), but scale is ", scale.dim(), "D for arg ", arg_idx
"for block-scaled 2d-3d grouped GEMM, the 3d tensor of shape (G,K,N) must have a 2d scale of shape (G, blocked_scale_K * blocked_scale_N),",
"but scale is ", scale.dim(), "D for arg ", arg_idx
);
TORCH_CHECK(
scale.size(0) == G && scale.size(1) == blocked_scale_K * blocked_scale_N,
"for mxfp8, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ") for arg ", arg_idx
"for block-scaled grouped GEMM, the tensor shape (", G, ", ", K, ", ", N, ") must have scale shape (", G, ",", blocked_scale_K, ",", blocked_scale_N, ")",
" for arg ", arg_idx, ", got: ", scale.size(0), ", ", scale.size(1)
);
}
}

void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
bool using_fp8_rowwise = scale.scalar_type() == kFloat;
bool using_mxfp8 = scale.scalar_type() == at::kFloat8_e8m0fnu;
bool using_mx = scale.scalar_type() == at::kFloat8_e8m0fnu;
if (using_fp8_rowwise) {
_check_scales_fp8_rowwise(mat, scale, dim, arg_idx, scale_multiplier);
} else if (using_mxfp8) {
_check_scales_mxfp8(mat, scale, dim, arg_idx);
} else if (using_mx) {
_check_scales_blocked(mat, scale, dim, arg_idx);
} else {
TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype());
}
@@ -411,9 +471,10 @@ namespace {

using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;

std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};

} // anonymous namespace

@@ -525,8 +586,9 @@ _scaled_grouped_mm_cuda_v2(
out);
}
case ScaledGemmImplementation::MXFP8_MXFP8: {
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _mx8_mx8_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,

@@ -537,6 +599,21 @@ _scaled_grouped_mm_cuda_v2(
offs.value(),
out);
}
case ScaledGemmImplementation::NVFP4_NVFP4: {
// scale shape checks
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
return _f4_f4_bf16_grouped_mm_fbgemm(
mat_a,
mat_b,
scale_a[0], /* block-scale A */
scale_a[1], /* global-scale A */
scale_b[0], /* block-scale B */
scale_b[1], /* global-scale B */
offs.value(),
std::nullopt, /* bias */
out);
}
default:
TORCH_CHECK_NOT_IMPLEMENTED(false,
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");

@@ -22,6 +22,7 @@
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>

@@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>

#include <c10/util/Exception.h>

namespace FLASH_NAMESPACE {
@@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);

if (batch_size == 0) {
auto opts = q.options();
at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor p = at::empty({0}, opts);
if (return_softmax) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
}
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
}

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");

@@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

const auto sizes = q.sizes();

@@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);

if (batch_size == 0) {
auto opts = q.options();
at::Tensor dq = at::empty_like(q);
at::Tensor dk = at::empty_like(k);
at::Tensor dv = at::empty_like(v);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
return {dq, dk, dv, softmax_d};
}

TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");

@@ -1066,6 +1066,8 @@ coverage_ignore_functions = [
"set_current_meta",
"set_grad_fn_seq_nr",
"set_stack_trace",
"set_current_replay_node",
"get_current_replay_node",
# torch.jit.annotations
"ann_to_type",
"check_fn",

@@ -99,6 +99,12 @@ DTensor supports the following types of {class}`Placement` on each {class}`Devic
:undoc-members:
```

```{eval-rst}
.. autoclass:: MaskPartial
:members:
:undoc-members:
```

```{eval-rst}
.. autoclass:: Placement
:members:

@@ -22,7 +22,11 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,

@@ -32,6 +36,7 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._pytree import register_pytree_node

class SimpleModel(torch.nn.Module):

@@ -176,6 +181,15 @@ def _count_op(gm, target):
return sum(1 for node in gm.graph.nodes if node.target == target)

register_pytree_node(
BlockMask,
BlockMask._flatten,
BlockMask._unflatten,
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)

@requires_cuda
class DTensorExportTest(TestCase):
def tearDown(self):

@@ -168,7 +168,7 @@ class TestEmbeddingOp(DTensorTestBase):
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial

# test collectives
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)

@@ -176,7 +176,7 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (8, 8), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
output = sharded_embedding(replicated_inp)
self.assertIsInstance(output.placements[0], _MaskPartial)
self.assertIsInstance(output.placements[0], MaskPartial)

comm_mode = CommDebugMode()

@@ -192,9 +192,9 @@ class TestEmbeddingOp(DTensorTestBase):
inp = torch.randint(0, 10, (4, 4), device=self.device_type)
replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial

# case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
# case 1: two embeddings with the same shape, thus sharing the underlying MaskPartial
# and MaskBuffer, because of cache hit from sharding propagation

emb1 = torch.nn.Embedding(10, 23, device=self.device_type)

@@ -206,23 +206,23 @@ class TestEmbeddingOp(DTensorTestBase):
output2 = sharded_emb2(replicated_inp)

partial_placement1 = output1.placements[0]
self.assertIsInstance(partial_placement1, _MaskPartial)
self.assertIsInstance(partial_placement1, MaskPartial)
output1.full_tensor()

partial_placement2 = output2.placements[0]
self.assertIsInstance(partial_placement2, _MaskPartial)
self.assertIsInstance(partial_placement2, MaskPartial)
output2.full_tensor()

self.assertTrue(id(partial_placement1), id(partial_placement2))

# case 2: two embeddings with the same logical_dim_size, but different logical_shape
# thus they will have different _MaskPartial placements (with no cache hit)
# thus they will have different MaskPartial placements (with no cache hit)

emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
output3 = sharded_emb3(replicated_inp)
partial_placement3 = output3.placements[0]
self.assertIsInstance(partial_placement3, _MaskPartial)
self.assertIsInstance(partial_placement3, MaskPartial)
output2.full_tensor()

# not equal because of different logical_shape, despite of same logical_dim_size

@@ -511,7 +511,7 @@ class DistTensorOpsTest(DTensorTestBase):
# case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor.placement_types import MaskPartial

gather_dim = 1
global_input = torch.randn(12, 8, 16)

@@ -522,7 +522,7 @@ class DistTensorOpsTest(DTensorTestBase):
with comm_mode:
output_dt = torch.gather(input_dt, gather_dim, index_dt)
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertIsInstance(output_dt.placements[0], _MaskPartial)
self.assertIsInstance(output_dt.placements[0], MaskPartial)
self.assertEqual(output_dt.full_tensor(), global_output)

# case 3 index sharding: input replicated, index sharded, output sharded

@@ -230,7 +230,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 12)
self.assertEqual(cnts.op_count, 20)

@unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/118204
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")

@@ -335,7 +335,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 21)
self.assertEqual(cnts.op_count, 37)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_compared_with_constant(self):

@@ -517,7 +517,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x, cur_stream, new_stream)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 19)
self.assertEqual(cnts.op_count, 27)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_event_method(self):

@@ -557,7 +557,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 19)
self.assertEqual(cnts.op_count, 27)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_device(self):

@@ -1016,6 +1016,59 @@ class inner_f(torch.nn.Module):
self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))

def test_preserve_annotate_replay_view(self):
"""Test stack trace and annotation are correct on nodes regenerated in functionalization"""

def _unpermute(out, input_shape, permuted_indices):
"""
Unpermute operation from torchtitan MoE utils.
"""
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
return out

class Module(nn.Module):
def __init__(self):
super().__init__()
self.input_shape = (5, 3)
self.permuted_indices = torch.tensor([2, 0, 3, 1])

def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
routed_output = _unpermute(
x, self.input_shape, self.permuted_indices
)
return routed_output.cos()

inputs = (torch.randn(4, 3, requires_grad=True),)
model = Module()

graph_module = graph_capture(model, inputs, True)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
slice_nodes = graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.slice.Tensor
)
self.assertEqual(len(slice_nodes), 1)
slice_backward_nodes = graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.slice_backward.default
)
self.assertEqual(len(slice_backward_nodes), 1)
slice_node = slice_nodes[0]
slice_backward_node = slice_backward_nodes[0]

self.assertEqual(slice_node.meta["seq_nr"], slice_backward_node.meta["seq_nr"])
self.assertTrue("out = out_unpermuted[:-1]" in slice_node.meta["stack_trace"])
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'new_empty', {'pp_stage': 0})
('call_function', 'index_put', {'pp_stage': 0})
('call_function', 'slice_2', {'pp_stage': 0})
('call_function', 'slice_backward', {'pp_stage': 0})
('call_function', 'index', {'pp_stage': 0})""",
)

if __name__ == "__main__":
run_tests()

@@ -3245,8 +3245,8 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
as_strided_9 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_9, [4]); as_strided_9 = None
return (as_strided_scatter, view_1)""",
) # noqa: B950

@@ -3409,13 +3409,13 @@ def forward(self, primals_1, primals_2, primals_3):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
unsqueeze = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze); add_1 = None
as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
return (as_strided_scatter, add_2, view_2, unsqueeze)""",
) # noqa: B950

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")

@@ -14269,6 +14269,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue("'enable_fp_fusion': False" in code)
torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)

@requires_cuda_and_triton
@config.patch(runtime_triton_nan_asserts=True)
def test_nan_assert_inside_triton_kernel(self):
def fn(x):
x = x - 1
# Uncomment the following line can trigger the failure of
# the device size assertion
# x = torch.log(x)
return torch.where(x.isnan(), 3.14, x)

compiled = torch.compile(fn)
x = torch.randn(4096, device=GPU_TYPE)
out, (code,) = run_and_get_code(compiled, x)
self.assertTrue("'NaN or Inf found'" in code)
torch.testing.assert_close(out, fn(x))

@skip_if_cpp_wrapper("skip cpp wrapper")
@requires_cuda_and_triton
def test_repeat_interleave_decomposition_has_clamp(self):

@@ -27,7 +27,6 @@ import torch
import torch.distributed as dist
from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
get_report_dir,
get_report_path,
IS_CI,
IS_MACOS,

@@ -35,6 +34,7 @@ from torch.testing._internal.common_utils import (
set_cwd,
shell,
TEST_CUDA,
TEST_SAVE_XML,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_SLOW_GRADCHECK,

@@ -529,14 +529,6 @@ def run_test(
replacement = {"-f": "-x", "-dist=loadfile": "--dist=loadfile"}
unittest_args = [replacement.get(arg, arg) for arg in unittest_args]

xml_report_dir = get_report_dir(test_file, None, options.pytest)
if is_cpp_test:
unittest_args.append(
f"--junit-xml-reruns={get_report_path(xml_report_dir, test_file)}"
)
else:
unittest_args.append(f"--save-xml={xml_report_dir}")

if options.showlocals:
if options.pytest:
unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"])

@@ -763,6 +755,8 @@ def run_test_retries(
REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
) as f:
current_failure = f.read()
if current_failure == "null":
current_failure = f"'{test_file}'"
except FileNotFoundError:
print_to_file(
"No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"

@@ -799,8 +793,6 @@ def run_test_retries(
print_to_file("Retrying single test...")
print_items = [] # do not continue printing them, massive waste of space

if "null" in num_failures:
num_failures[f"'{test_file}'"] = num_failures.pop("null")
consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
if len(flaky_failures) > 0:

@@ -1234,6 +1226,12 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
# is much slower than running them directly
pytest_args.extend(["-n", str(NUM_PROCS)])

if TEST_SAVE_XML:
# Add the option to generate XML test report here as C++ tests
# won't go into common_utils
test_report_path = get_report_path(pytest=True)
pytest_args.extend(["--junit-xml-reruns", test_report_path])

if options.pytest_k_expr:
pytest_args.extend(["-k", options.pytest_k_expr])

@@ -46,6 +46,7 @@ from torch.testing._internal.common_quantized import (
_floatx_unpacked_to_f32,
ceil_div, to_blocked,
to_mxfp8,
from_blocked_format,
generate_jagged_offs,
)

@@ -462,6 +463,24 @@ def pack_uint4(uint8_data) -> torch.Tensor:
uint8_data = uint8_data.contiguous().view(-1)
return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))

def unpack_uint4(uint8_data) -> torch.Tensor:
# Take a packed uint8 tensor (i.e. nvfp4) and unpack into
# a tensor twice as wide. Useful for dequant operations.
shape = list(uint8_data.shape)
# 2x packed elements -> single non-packed => adjust shape
shape[-1] *= 2
out = torch.empty(
*shape,
device=uint8_data.device,
dtype=torch.uint8
).view(-1)

uint8_data_as_uint8 = uint8_data.view(torch.uint8).view(-1)

out[1::2] = uint8_data_as_uint8[:] >> 4
out[::2] = uint8_data_as_uint8 & 15

return out.view(shape)

def _bfloat16_to_float4_e2m1fn_x2(x):
assert x.dtype == torch.bfloat16
@@ -470,6 +489,119 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
x = x.view(torch.float4_e2m1fn_x2)
return x

def _convert_to_nvfp4_with_hp_ref(t):
# Convert a tensor to nvfp4, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : nvfp4 tensor (2x elements packed into uint8)
# t_scale: e4m3 block-wise scaling factors (non-swizzled)
# t_global_scale: fp32 tensor-wise global scaling factor
t_lp, t_scale, t_global_scale = data_to_nvfp4_with_global_scale(
t,
16,
)
t_hp = from_blocked_format(
_floatx_unpacked_to_f32(
unpack_uint4(t_lp),
FP4_EBITS,
FP4_MBITS),
t_scale,
blocksize=16) * t_global_scale

return t_hp, t_lp, t_scale, t_global_scale

def _convert_to_mxfp8_with_hp_ref(t):
# Convert a tensor to mxfp8, returning:
# t_hp : reconstructed bf16 version of t_lp
# t_lp : fp8_e4m3 tensor
# t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled)
t_scale, t_lp = to_mxfp8(t)
t_hp = from_blocked_format(t_lp, t_scale, blocksize=32)

return t_hp, t_lp, t_scale

def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'):
# Convert scales to blocked format. either mxfp8 or nvfp4
th_list = []
t_list = []
t_blocked_scale_list = []
t_global_scale_list = []

def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y

for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else offs[group_idx - 1]
)
curr_group_end_offset = offs[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
t_slice = t[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
if format == 'mxfp8':
th_slice, tq_slice, t_scale_slice = _convert_to_mxfp8_with_hp_ref(t_slice)
elif format == 'nvfp4':
th_slice, tq_slice, t_scale_slice, tq_global = _convert_to_nvfp4_with_hp_ref(
t_slice,
)
t_global_scale_list.append(tq_global)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')
t_list.append(tq_slice)
th_list.append(th_slice)

# Convert scales to blocked format.
t_scale_slice_blocked = to_blocked(
t_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
t_blocked_scale_list.append(t_scale_slice_blocked)

# Assemble the full XQ and WQ
tq = torch.cat(t_list, dim=1).contiguous()
th = torch.cat(th_list, dim=1).contiguous()

# Combine all XQ groups blocked scales into one tensor.
t_blocked_scales = torch.cat(t_blocked_scale_list, dim=0)
MN_rounded = round_up(MN, 128)
t_blocked_scales = t_blocked_scales.reshape(MN_rounded, -1)

# Global scales only exist for nvfp4
t_global_scales = None
if len(t_global_scale_list) > 0:
t_global_scales = torch.stack(t_global_scale_list)

return th, tq, t_blocked_scales, t_global_scales

def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format):
# Build some standard args that are wordy
# Note: if/when ROCm support added, need to change swizzle handling
kwargs = {
'mxfp8': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': ScalingType.BlockWise1x32,
'scale_recipe_b': ScalingType.BlockWise1x32,
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
'nvfp4': {
'scale_a': scale_a,
'scale_b': scale_b,
'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise],
'swizzle_a': SwizzleType.SWIZZLE_32_4_4,
'swizzle_b': SwizzleType.SWIZZLE_32_4_4,
'offs': offs, # (G,)
'out_dtype': torch.bfloat16,
'wrap_v2': True,
},
}
return kwargs[format]

class TestFP8Matmul(TestCase):

@@ -526,13 +658,15 @@ class TestFP8Matmul(TestCase):
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)

@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [2048, 2049])
@parametrize("N", [8192])
@parametrize("K", [16640])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format):
torch.manual_seed(42)
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
input_group_end_offsets = generate_jagged_offs(
@@ -541,95 +675,61 @@ class TestFP8Matmul(TestCase):
X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1
W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01

# Convert scales to blocked format.
x_list = []
w_list = []
x_blocked_scale_list = []
w_blocked_scale_list = []
xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
X, M, G, input_group_end_offsets, format=format
)
wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled(
W, N, G, input_group_end_offsets, format=format
)

def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y

for group_idx in range(G):
# to_mxfp8 per group
prev_group_end_offset = (
0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
curr_group_end_offset = input_group_end_offsets[group_idx]
group_size = curr_group_end_offset - prev_group_end_offset
if group_size > 0:
x_slice = X[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (M, K_group)
w_slice = W[
:, prev_group_end_offset:curr_group_end_offset
].contiguous() # (N, K_group)
x_scale_slice, xq_slice = to_mxfp8(
x_slice
) # scale shape -> (M, K_group // 32)
w_scale_slice, wq_slice = to_mxfp8(
w_slice
) # scale shape -> (N, K_group // 32)
x_list.append(xq_slice)
w_list.append(wq_slice)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

# Convert scales to blocked format.
x_scale_slice_blocked = to_blocked(
x_scale_slice
) # (round_up(M, 128), round_up(K_group//32, 4))
w_scale_slice_blocked = to_blocked(
w_scale_slice
) # (round_up(N, 128), round_up(K_group//32, 4))
x_blocked_scale_list.append(x_scale_slice_blocked)
w_blocked_scale_list.append(w_scale_slice_blocked)

# Assemble the full XQ and WQ
xq = torch.cat(x_list, dim=1).contiguous()
wq = torch.cat(w_list, dim=1).contiguous()

# Combine all XQ groups blocked scales into one tensor.
x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
M_rounded = round_up(M, 128)
x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)

# Combine all WQ groups blocked scales into one tensor.
w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
N_rounded = round_up(N, 128)
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G

# Compute mxfp8 grouped mm output
y_mxfp8 = scaled_grouped_mm_wrap(
xq, # (M, total_K)
wq.transpose(-2, -1), # (total_K, N)
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
offs=input_group_end_offsets, # (G,)
out_dtype=torch.bfloat16,
wrap_v2=wrap_v2
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
**kwargs,
)

# bf16 reference output
y_bf16 = torch._grouped_mm(
X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
# Note: Reference result should be on reconstructed, not original values.
# as-in float(fp4(t)) not t itself.
xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16
)

# Assert no NaNs
assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN"
assert not y_lp.isnan().any(), "mxfp8 output contains NaN"

# Assert outputs are close
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)

@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@parametrize("G", [1, 4, 16])
@parametrize("M", [16640])
@parametrize("N", [8192])
@parametrize("K", [4096])
@parametrize("wrap_v2", [True, False])
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
@parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else []))
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format):
torch.manual_seed(42)
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
# 2D inputs with groups along M, 3D weights.
@@ -643,60 +743,120 @@ class TestFP8Matmul(TestCase):

# For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
wq_list = []
w_scale_list = []
for i in range(G):
w_scale, wq = to_mxfp8(W[i])
w_scale = to_blocked(w_scale)
wq_list.append(wq)
w_scale_list.append(w_scale)
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
def _3d_to_blocked_scaled(W, G, format):
wh_list = []
wq_list = []
w_scale_list = []
w_global_scale_list = []
for i in range(G):
if format == "mxfp8":
wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i])
elif format == "nvfp4":
w_scale, wq = to_mxfp8(W[i])
wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i])
w_global_scale_list.append(w_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

# Swizzle scaled
# TODO(slayton): gate on cuda/hip
w_scale = to_blocked(w_scale)

wh_list.append(wh)
wq_list.append(wq)
w_scale_list.append(w_scale)
wh = torch.stack(wh_list, dim=0).contiguous()
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Global scales only exist for nvfp4
if len(w_global_scale_list) > 0:
w_global_scales = torch.stack(w_global_scale_list)
else:
w_global_scales = None
return wh, wq, w_scale, w_global_scales

wh, wq, w_blocked_scales, w_global_scales = _3d_to_blocked_scaled(W, G, format)

# For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately,
# as they each used for independent gemm in the grouped gemm.
xq_list = []
x_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
x_scale, xq = to_mxfp8(x_slice)
x_scale = to_blocked(x_scale)
xq_list.append(xq)
x_scale_list.append(x_scale)
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
def _2d_to_blocked_scaled(X, K, G, offs, format):
xh_list = []
xq_list = []
x_scale_list = []
x_global_scale_list = []
for i in range(G):
prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
curr_group_end = input_group_end_offsets[i]
group_size = curr_group_end - prev_group_end
if group_size > 0:
x_slice = X[prev_group_end:curr_group_end, :]
if format == "mxfp8":
xh, xq, x_scale = _convert_to_mxfp8_with_hp_ref(x_slice)
elif format == "nvfp4":
xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice)
x_global_scale_list.append(x_global_scale)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

# Compute mxfp8 grouped gemm.
y_mxfp8 = scaled_grouped_mm_wrap(
x_scale = to_blocked(x_scale)
xh_list.append(xh)
xq_list.append(xq)
x_scale_list.append(x_scale)
xh = torch.cat(xh_list, dim=0).contiguous()
xq = torch.cat(xq_list, dim=0).contiguous()
x_scale = torch.cat(x_scale_list, dim=0).contiguous()
x_scale = x_scale.reshape(-1, K // block_size)
xq = xq.view(-1, xq.shape[-1])
xh = xh.view(-1, xh.shape[-1])

x_global_scales = None
if len(x_global_scale_list) > 0:
x_global_scales = torch.stack(x_global_scale_list)

return xh, xq, x_scale, x_global_scales

xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format)

if format == "mxfp8":
kwargs = _build_scaled_grouped_mm_kwargs(
x_blocked_scales,
w_blocked_scales,
input_group_end_offsets,
format,
)
elif format == "nvfp4":
kwargs = _build_scaled_grouped_mm_kwargs(
[x_blocked_scales, x_global_scales],
[w_blocked_scales, w_global_scales],
input_group_end_offsets,
format,
)
else:
raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"')

if format == 'nvfp4':
assert x_global_scales.numel() == w_global_scales.numel()
assert x_global_scales.numel() == G

# Compute low-precision grouped gemm.
y_lp = scaled_grouped_mm_wrap(
xq,
wq.transpose(-2, -1),
x_scale,
w_scale,
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
scale_recipe_a=ScalingType.BlockWise1x32,
scale_recipe_b=ScalingType.BlockWise1x32,
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
wrap_v2=wrap_v2)

**kwargs
)

# Compute reference bf16 grouped gemm.
# Note: Reference result should be on reconstructed, not original values.
# as-in float(fp4(t)) not t itself.
y_bf16 = torch._grouped_mm(
X,
W.transpose(-2, -1),
xh,
wh.transpose(-2, -1),
offs=input_group_end_offsets,
out_dtype=torch.bfloat16,
)

# Assert outputs are close.
torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2)
torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2)

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@@ -1704,6 +1864,7 @@ class TestFP8Matmul(TestCase):
@parametrize("fast_accum", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
# AMD does not support NVFP4
@parametrize("wrap_v2", [True, False])
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
device = "cuda"

@@ -1107,6 +1107,7 @@ class TestTransformers(NNTestCase):
)[0]

@tf32_on_and_off(0.003)
@parametrize("batch_size", [0, 5])
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],

@@ -1116,7 +1117,7 @@ class TestTransformers(NNTestCase):
if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdpa_kernel(backends=[SDPBackend.MATH])
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def test_scaled_dot_product_attention(self, device, batch_size, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref(
q,
k,

@@ -1140,12 +1141,13 @@ class TestTransformers(NNTestCase):
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
N = batch_size

def rand_tensor(*shape):
return torch.randn(shape, device=device, dtype=dtype)

# This test compares python and C++ implementations of SDP.
N, N_prime, L, S, E = 5, 2, 4, 3, 6
N_prime, L, S, E = 2, 4, 3, 6
if input_dim == 3:
query = rand_tensor(N, L, E)
key = rand_tensor(N, S, E)

@@ -5,11 +5,12 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.nn.attention.varlen import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode

VarlenShape = namedtuple(

@@ -23,6 +24,18 @@ default_tolerances = {
}

class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""

def __init__(self):
self.called_ops = []

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))

class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype

@@ -39,12 +52,9 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)

def forward_varlen(
def get_varlen_qkv(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@@ -53,24 +63,51 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)

attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
return q, k, v

def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)

attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = attn_out.view(-1, self.embed_dim)

return self.out_proj(attn_out)

def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
batch_size, seq_len, _ = x_padded.shape

qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)

mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)

attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)

q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)

attn_out = (
attn_out.transpose(1, 2)
.contiguous()

@@ -91,7 +128,9 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()

x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)

cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)

@@ -106,6 +145,7 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()

return {
"seq_lengths": seq_lengths,

@@ -133,7 +173,11 @@ class TestVarlenAttention(NNTestCase):

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -147,6 +191,128 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)

varlen_grad_out = torch.ones_like(output)

varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)

shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)

q, k, v = attention_block.get_varlen_qkv(x_packed)

torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)

out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)

# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)

shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)

compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)

varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

called_ops = mode.called_ops

custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)

@@ -172,7 +338,10 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
|
||||
variable_length_batch_data["x_padded"], is_causal=is_causal
|
||||
variable_length_batch_data["x_padded"],
|
||||
variable_length_batch_data["seq_lengths"],
|
||||
dtype=dtype,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
tolerances = default_tolerances[dtype]
|
||||
@ -186,6 +355,44 @@ class TestVarlenAttention(NNTestCase):
|
||||
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
|
||||
start_idx = end_idx
|
||||
|
||||
varlen_grad_out = torch.ones_like(varlen_output)
|
||||
|
||||
sdpa_grad_out = torch.zeros_like(sdpa_output)
|
||||
|
||||
start_idx = 0
|
||||
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
|
||||
end_idx = start_idx + seq_len
|
||||
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
|
||||
varlen_grad = torch.autograd.grad(
|
||||
outputs=varlen_output,
|
||||
inputs=variable_length_batch_data["x_packed"],
|
||||
grad_outputs=varlen_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
sdpa_grad = torch.autograd.grad(
|
||||
outputs=sdpa_output,
|
||||
inputs=variable_length_batch_data["x_padded"],
|
||||
grad_outputs=sdpa_grad_out,
|
||||
retain_graph=True,
|
||||
create_graph=False,
|
||||
allow_unused=False,
|
||||
)[0]
|
||||
|
||||
start_idx = 0
|
||||
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
|
||||
end_idx = start_idx + seq_len
|
||||
|
||||
varlen_grad_seq = varlen_grad[start_idx:end_idx]
|
||||
sdpa_grad_seq = sdpa_grad[i, :seq_len]
|
||||
|
||||
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
|
||||
start_idx = end_idx
|
||||
|
||||
|
||||
device_types = ("cuda",)
Submodule third_party/kineto updated: a6b2477b88...6fcbc53d33
@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True

# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = not is_fbcode()
enable_cpp_symbolic_shape_guards = False

# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True

@ -2820,5 +2820,36 @@
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ],
    "GB0280": [
        {
            "Gb_type": "1-arg super not implemented",
            "Context": "",
            "Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because one-argument of super() is not supported.",
            "Hints": [
                "Use two-argument super(type, object_or_type)."
            ]
        }
    ],
    "GB0281": [
        {
            "Gb_type": "Invalid or non-const argument in nn.Module __getitem__",
            "Context": "call_method: {self} {name} {args} {kwargs}",
            "Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
            "Hints": [
                "Use constant arguments of type str or int for __getitem__"
            ]
        }
    ],
    "GB0282": [
        {
            "Gb_type": "Placement with custom __getattr__ not supported",
            "Context": "{value_type.__name__} with custom __getattr__",
            "Explanation": "Dynamo does not support Placement types with custom __getattr__ methods",
            "Hints": [
                "Use Placement types without custom __getattr__ methods",
                "Move the Placement usage outside the compiled region"
            ]
        }
    ]
}

@ -2299,10 +2299,13 @@ class GuardBuilder(GuardBuilderBase):
|
||||
],
|
||||
)
|
||||
|
||||
def FUNCTION_MATCH(self, guard: Guard) -> None:
|
||||
"""things like torch.add and user defined functions"""
|
||||
# don't support this in serialization because it uses unsupported ID_MATCH
|
||||
return self.ID_MATCH(guard)
|
||||
def UNCLASSIFIED_ID_MATCH(self, guard: Guard) -> None:
|
||||
"""
|
||||
Calls id_match guard but also helps with future debugging where we are
|
||||
calling ID_MATCH on an object that we don't understand why. This will
|
||||
show up in tlparse.
|
||||
"""
|
||||
self.id_match_unchecked(guard)
|
||||
|
||||
def CLASS_MATCH(self, guard: Guard) -> None:
|
||||
"""Equals ID_MATCH on classes - better readability than directly calling ID_MATCH"""
|
||||
@ -2324,14 +2327,13 @@ class GuardBuilder(GuardBuilderBase):
|
||||
|
||||
def CLOSURE_MATCH(self, guard: Guard) -> None:
|
||||
"""matches a closure by __code__ id."""
|
||||
# don't support this in serialization because it uses unsupported FUNCTION_MATCH
|
||||
val = self.get(guard.name)
|
||||
# Strictly only want user-defined functions
|
||||
if type(val) is types.FunctionType and hasattr(val, "__code__"):
|
||||
self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
|
||||
self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
|
||||
self._guard_on_attribute(guard, "__code__", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type]
|
||||
else:
|
||||
self.FUNCTION_MATCH(guard)
|
||||
self.UNCLASSIFIED_ID_MATCH(guard)
|
||||
|
||||
def BUILTIN_MATCH(self, guard: Guard) -> None:
|
||||
if self.save_guards:
|
||||
@ -3718,11 +3720,11 @@ class CheckFunctionManager:
|
||||
"DICT_VERSION",
|
||||
"NN_MODULE",
|
||||
"ID_MATCH",
|
||||
"FUNCTION_MATCH",
|
||||
"CLASS_MATCH",
|
||||
"MODULE_MATCH",
|
||||
"CLOSURE_MATCH",
|
||||
"WEAKREF_ALIVE",
|
||||
"UNCLASSIFIED_ID_MATCH",
|
||||
)
|
||||
|
||||
def serialize_guards(
|
||||
|
||||
@ -629,7 +629,7 @@ class VariableBuilder:
|
||||
lambda self, value: LambdaVariable(
|
||||
_dataclasses_fields_lambda,
|
||||
source=self.source,
|
||||
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
|
||||
**self.install_guards(GuardBuilder.CLOSURE_MATCH),
|
||||
),
|
||||
),
|
||||
(torch.__version__, lambda self, value: TorchVersionVariable()),
|
||||
@ -927,8 +927,10 @@ class VariableBuilder:
|
||||
)
|
||||
elif inspect.isclass(value):
|
||||
self.install_guards(GuardBuilder.CLASS_MATCH)
|
||||
elif inspect.isfunction(value):
|
||||
self.install_guards(GuardBuilder.CLOSURE_MATCH)
|
||||
elif callable(value):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
else:
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return NumpyVariable(value, source=self.source)
|
||||
@ -945,7 +947,7 @@ class VariableBuilder:
|
||||
return NumpyTypeInfoVariable(value, source=self.source)
|
||||
# NB: These can't be put in type_dispatch, they have to run later
|
||||
elif CollectiveFunctionRewriteVariable.can_rewrite(value):
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
self.install_guards(GuardBuilder.CLOSURE_MATCH)
|
||||
return CollectiveFunctionRewriteVariable.create(
|
||||
self.tx,
|
||||
value,
|
||||
@ -1371,7 +1373,7 @@ class VariableBuilder:
|
||||
elif isinstance(value, types.MethodWrapperType):
|
||||
# Method-wrappers are written in C, and they are not guaranteed to
|
||||
# return the same object on attribute lookup. Therefore, we cannot
|
||||
# insert a FUNCTION_MATCH guard here. method-wrappers are very
|
||||
# insert a ID_MATCH guard here. method-wrappers are very
|
||||
# unlikely to change, so its ok to skip the guard here.
|
||||
return MethodWrapperVariable(value)
|
||||
elif issubclass(type(value), type) and issubclass(value, BaseException):
|
||||
|
||||
@ -24,10 +24,12 @@ import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import ExitStack
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch._C
|
||||
from torch._dynamo.variables.misc import GetAttrVariable
|
||||
from torch._guards import Guard
|
||||
from torch.fx import Proxy
|
||||
|
||||
from .. import graph_break_hints, variables
|
||||
from ..bytecode_transformation import (
|
||||
@ -41,6 +43,7 @@ from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, GlobalStateSource
|
||||
from ..utils import _get_error_on_graph_break, _set_error_on_graph_break
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .functions import (
|
||||
NestedUserFunctionVariable,
|
||||
SkipFunctionVariable,
|
||||
@ -992,13 +995,82 @@ class ProfilerContextVariable(ContextWrappingVariable):
|
||||
|
||||
|
||||
class StreamContextVariable(ContextWrappingVariable):
|
||||
"""This represents torch.cuda.StreamContext"""
|
||||
|
||||
@staticmethod
|
||||
def create(tx: "InstructionTranslator", target_value, **kwargs):
|
||||
def create(
|
||||
tx: "InstructionTranslator",
|
||||
target_value: "StreamVariable",
|
||||
**kwargs: dict[str, Any],
|
||||
) -> "StreamContextVariable":
|
||||
return StreamContextVariable(
|
||||
target_values=[target_value],
|
||||
initial_values=[
|
||||
StreamContextVariable._get_current_stream(target_value.device, tx)
|
||||
],
|
||||
device=target_value.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_values: list["StreamVariable"],
|
||||
device: torch.device,
|
||||
initial_values: Optional[list["StreamVariable"]] = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
target_values=target_values, initial_values=initial_values, **kwargs
|
||||
)
|
||||
self.device = device
|
||||
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
|
||||
|
||||
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
|
||||
# to stream, from stream is the order of the arguments
|
||||
# we are entering the target, and leaving the initial stream
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch.ops.streams.fork.default,
|
||||
self._target_stream_proxies() + self._initial_stream_proxies(),
|
||||
{},
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
|
||||
# to stream, from stream is the order of the arguments
|
||||
# we are leaving the target, and entering the initial stream
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch.ops.streams.join.default,
|
||||
self._initial_stream_proxies() + self._target_stream_proxies(),
|
||||
{},
|
||||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
|
||||
assert self.initial_values, "No initial stream to move from"
|
||||
return StreamContextVariable._extract_stream_properties(
|
||||
self.initial_values[0].as_proxy()
|
||||
)
|
||||
|
||||
def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
|
||||
return StreamContextVariable._extract_stream_properties(
|
||||
self.target_values[0].as_proxy()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
|
||||
stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
|
||||
stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
|
||||
return stream_index, stream_device
|
||||
|
||||
@staticmethod
|
||||
def _get_current_stream(
|
||||
device: torch.device, tx: "InstructionTranslator"
|
||||
) -> "StreamVariable":
|
||||
from .builder import wrap_fx_proxy_cls
|
||||
|
||||
current_stream_method = get_interface_for_device(
|
||||
target_value.device
|
||||
).current_stream
|
||||
current_stream_method = get_interface_for_device(device).current_stream
|
||||
current_stream = wrap_fx_proxy_cls(
|
||||
StreamVariable,
|
||||
tx,
|
||||
@ -1009,50 +1081,7 @@ class StreamContextVariable(ContextWrappingVariable):
|
||||
{},
|
||||
),
|
||||
)
|
||||
return StreamContextVariable(
|
||||
target_values=[target_value],
|
||||
initial_values=[current_stream],
|
||||
device=target_value.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
|
||||
super().__init__(
|
||||
target_values=target_values, initial_values=initial_values, **kwargs
|
||||
)
|
||||
self.device = device
|
||||
self.set_stream = get_interface_for_device(self.device).set_stream
|
||||
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
|
||||
|
||||
def enter(self, tx):
|
||||
# stream generated inside the traced function
|
||||
if self.target_values[0].as_proxy() is not None:
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream,
|
||||
(self.target_values[0].as_proxy(),),
|
||||
{},
|
||||
)
|
||||
# stream passed from outside the traced function
|
||||
else:
|
||||
stream = self.target_values[0].value
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream_id,
|
||||
(stream.stream_id, stream.device_index, stream.device_type),
|
||||
{},
|
||||
)
|
||||
self.set_stream(self.target_values[0].value)
|
||||
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
|
||||
|
||||
def exit(self, tx: "InstructionTranslator", *args):
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.set_stream,
|
||||
(self.initial_values[0].as_proxy(),),
|
||||
{},
|
||||
)
|
||||
self.cleanup_assert()
|
||||
return current_stream
|
||||
|
||||
|
||||
class PreserveVersionContextVariable(ContextWrappingVariable):
|
||||
|
||||
@ -210,9 +210,16 @@ class PlacementVariable(DistributedVariable):
|
||||
if name in constant_fold_functions:
|
||||
try:
|
||||
value_type = type(self.value)
|
||||
assert (
|
||||
inspect.getattr_static(value_type, "__getattr__", None) is None
|
||||
), "no custom getattr allowed!"
|
||||
if inspect.getattr_static(value_type, "__getattr__", None) is not None:
|
||||
unimplemented_v2(
|
||||
gb_type="Placement with custom __getattr__ not supported",
|
||||
context=f"{value_type.__name__} with custom __getattr__",
|
||||
explanation="Dynamo does not support Placement types with custom __getattr__ methods",
|
||||
hints=[
|
||||
"Use Placement types without custom __getattr__ methods",
|
||||
"Move the Placement usage outside the compiled region",
|
||||
],
|
||||
)
|
||||
method = inspect.getattr_static(value_type, name)
|
||||
except AttributeError:
|
||||
method = None
|
||||
|
||||
@ -2001,7 +2001,7 @@ class PolyfilledFunctionVariable(VariableTracker):
|
||||
|
||||
@classmethod
|
||||
def create_with_source(cls, value, source):
|
||||
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
|
||||
|
||||
return cls(value, source=source)
|
||||
|
||||
|
||||
@ -103,7 +103,17 @@ class SuperVariable(VariableTracker):
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
|
||||
assert self.objvar, "1-arg super not implemented"
|
||||
if not self.objvar:
|
||||
unimplemented_v2(
|
||||
gb_type="1-arg super not implemented",
|
||||
context="",
|
||||
explanation=f"Dynamo failed to trace attribute `{name}` accessed "
|
||||
f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
|
||||
"because one-argument of super() is not supported.",
|
||||
hints=[
|
||||
"Use two-argument super(type, object_or_type).",
|
||||
],
|
||||
)
|
||||
search_type = self.typevar.as_python_constant()
|
||||
|
||||
# The rest of this function does two things:
|
||||
@ -1032,10 +1042,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
|
||||
assert tx.one_graph or tx.error_on_graph_break, (
|
||||
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
|
||||
)
|
||||
# queue_callback is a method-wrapper, no need to insert a guard.
|
||||
fn_vt = VariableTracker.build(
|
||||
tx,
|
||||
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
|
||||
source=self.source,
|
||||
)
|
||||
return fn_vt.call_function(
|
||||
tx,
|
||||
|
||||
@ -822,9 +822,19 @@ class NNModuleVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if type(module).__getitem__ not in builtin_supported:
|
||||
assert isinstance(args[0], variables.ConstantVariable), typestr(args[0])
|
||||
key = args[0].as_python_constant()
|
||||
assert isinstance(key, (str, int))
|
||||
if not (
|
||||
isinstance(args[0], variables.ConstantVariable)
|
||||
and isinstance(args[0].as_python_constant(), (str, int))
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Invalid or non-const argument in nn.Module __getitem__",
|
||||
context=f"call_method: {self} {name} {args} {kwargs}",
|
||||
explanation="Dynamo does not support calling "
|
||||
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
|
||||
hints=[
|
||||
"Use constant arguments of type str or int for __getitem__"
|
||||
],
|
||||
)
|
||||
fn = getattr(module, name).__func__
|
||||
|
||||
assert isinstance(fn, types.FunctionType)
|
||||
|
||||
@ -262,7 +262,9 @@ class BaseTorchVariable(VariableTracker):
|
||||
# Dont need to guard on wrappers
|
||||
pass
|
||||
else:
|
||||
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
# Installing an ID_MATCH to preserve the old behavior. But making it
|
||||
# unclassified so that we can eventually remove it.
|
||||
install_guard(source.make_guard(GuardBuilder.UNCLASSIFIED_ID_MATCH))
|
||||
return cls(value, source=source)
|
||||
|
||||
def __init__(self, value, **kwargs) -> None:
|
||||
|
||||
@ -720,13 +720,22 @@ def check_shape(
|
||||
) -> None:
|
||||
backend = get_current_backend()
|
||||
assert shape is not None
|
||||
if config.test_configs.runtime_triton_dtype_assert and backend == "triton":
|
||||
if config.test_configs.runtime_triton_shape_assert and backend == "triton":
|
||||
shape_str = (
|
||||
", ".join(str(d) for d in shape) if len(shape) != 1 else f"{shape[0]},"
|
||||
)
|
||||
buffer.writeline(f"tl.static_assert({var}.shape == ({shape_str}))")
|
||||
|
||||
|
||||
def check_nan(buffer: IndentedBuffer, var: CSEVariableType) -> None:
|
||||
backend = get_current_backend()
|
||||
if backend == "triton":
|
||||
msg = "NaN or Inf found"
|
||||
buffer.writeline(
|
||||
f"tl.device_assert(({var} == {var}) & ({var} != float('inf')) & ({var} != float('-inf')), '{msg}')"
|
||||
)
|
||||
|
||||
|
||||
class DataTypePropagation:
|
||||
def __init__(self, body: LoopBody) -> None:
|
||||
self.body = body
|
||||
@ -2623,6 +2632,9 @@ class CSEProxy(DefaultHandler):
|
||||
assert output_shape is not None
|
||||
check_shape(V.kernel.compute, csevar, output_shape)
|
||||
|
||||
if config.runtime_triton_nan_asserts:
|
||||
check_nan(V.kernel.compute, csevar)
|
||||
|
||||
return csevar
|
||||
|
||||
return pytree.tree_map(do_cse, value)
|
||||
|
||||
@ -626,7 +626,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
filename=__file__,
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -206,6 +206,9 @@ static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
    os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"

# Disable by default in fbcode

@ -3550,24 +3550,13 @@ def user_autotune(
    )


def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    configs = []

    # Naive autotuning path for num_warps
    if not inductor_meta.get("autotune_pointwise", True) and not (
        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
    ):
        configs.append(triton.Config({}, num_stages=1, num_warps=8))
    else:
        for warps in [1, 2, 4, 8]:
            configs.append(triton.Config({}, num_stages=1, num_warps=warps))

    return cached_autotune(
        None,
        configs,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,

@ -8,6 +8,7 @@ from contextlib import AbstractContextManager
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import _functionalization_reapply_views_tls as _reapply_views
|
||||
from torch._ops import _get_dispatch_mode_pre_dispatch
|
||||
@ -512,6 +513,30 @@ class FunctionalTensorMode(TorchDispatchMode):
|
||||
torch.Tensor, wrap, outs_unwrapped
|
||||
)
|
||||
else:
|
||||
# Note: [Functionalization View Replay Annotation]
|
||||
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
|
||||
# at the first time they are next used.
|
||||
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
|
||||
# to have the same annotation that the user specified on the original views. But view replay in
|
||||
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
|
||||
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
|
||||
# for second_op!
|
||||
#
|
||||
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
|
||||
# is globally set when we lazily perform view replay.
|
||||
# The globally set metadata will be used to populate the fx node created for the replayed operation.
|
||||
if m := torch._C._get_dispatch_mode(
|
||||
torch._C._TorchDispatchModeKey.PROXY
|
||||
):
|
||||
for a in pytree.tree_leaves([args, kwargs]):
|
||||
if not isinstance(a, FunctionalTensor):
|
||||
continue
|
||||
curr_node = m.tracer.tensor_tracker[
|
||||
torch._from_functional_tensor(a.elem)
|
||||
].proxy.node
|
||||
with fx_traceback.set_current_replay_node(curr_node):
|
||||
torch._sync(a)
|
||||
|
||||
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
|
||||
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
|
||||
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
|
||||
|
||||
@ -1,13 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# implement matrix related ops for distributed tensor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast, Optional
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
from torch.distributed._local_tensor import maybe_run_for_local_tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._op_schema import (
|
||||
OpSchema,
|
||||
OpStrategy,
|
||||
@ -19,8 +15,8 @@ from torch.distributed.tensor._ops.utils import (
|
||||
register_op_strategy,
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import (
|
||||
MaskPartial,
|
||||
Partial,
|
||||
Placement,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
@ -29,190 +25,6 @@ from torch.distributed.tensor.placement_types import (
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskBuffer:
|
||||
data: Optional[torch.Tensor] = None
|
||||
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
|
||||
refcount: int = 0
|
||||
|
||||
def materialize_mask(self, mask):
|
||||
if self.refcount == 0:
|
||||
self.data = mask
|
||||
else:
|
||||
assert self.data is not None
|
||||
if not torch.equal(self.data, mask):
|
||||
raise RuntimeError(
|
||||
"MaskBuffer has been materialized with conflicting data"
|
||||
)
|
||||
self.refcount += 1
|
||||
|
||||
def release_mask(self):
|
||||
if self.refcount == 0 or self.data is None:
|
||||
raise RuntimeError("MaskBuffer has not been materialized")
|
||||
self.refcount -= 1
|
||||
if self.refcount == 0:
|
||||
self.data = None
|
||||
|
||||
def apply_mask(self, tensor):
|
||||
if self.refcount == 0 or self.data is None:
|
||||
raise RuntimeError("MaskBuffer has not been materialized")
|
||||
|
||||
# NOTE: _MaskPartial is being used by the embedding op and the gather op.
|
||||
# For gather, the mask has the same dimension as the output tensor, whereas
|
||||
# the output of the embedding op has an additional dimension compare to the input,
|
||||
# hence the output masking logic below having two different cases.
|
||||
if tensor.ndim == self.data.ndim:
|
||||
tensor[self.data] = 0.0
|
||||
else:
|
||||
tensor[self.data, :] = 0.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _MaskPartial(Partial):
|
||||
"""
|
||||
A partial mask placement devised for rowwise sharded embedding op, where we need
|
||||
to mask and adjust the indices to the local embedding shard, embedding masking
|
||||
is a special type of the Partial placement
|
||||
|
||||
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
|
||||
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
|
||||
"""
|
||||
|
||||
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
|
||||
|
||||
# required fields for computing the local offset and deriving the mask
|
||||
offset_shape: Optional[torch.Size] = None
|
||||
offset_dim: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_op=None,
|
||||
mask_buffer=None,
|
||||
offset_shape=None,
|
||||
offset_dim=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(reduce_op)
|
||||
if mask_buffer is None:
|
||||
mask_buffer = MaskBuffer()
|
||||
object.__setattr__(self, "mask_buffer", mask_buffer)
|
||||
object.__setattr__(self, "offset_shape", offset_shape)
|
||||
object.__setattr__(self, "offset_dim", offset_dim)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _mask_tensor(
|
||||
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Build the input mask and save it for the current partial placement
|
||||
# this is so that the output of embedding op can reuse the same partial
|
||||
# placement saved mask to perform mask + reduction
|
||||
mask = (tensor < local_offset_on_dim) | (
|
||||
tensor >= local_offset_on_dim + local_shard_size
|
||||
)
|
||||
# mask the input tensor
|
||||
masked_tensor = tensor.clone() - local_offset_on_dim
|
||||
masked_tensor[mask] = 0
|
||||
return mask, masked_tensor
|
||||
|
||||
def _partition_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
my_coordinate = mesh.get_coordinate()
|
||||
assert my_coordinate is not None, "my_coordinate should not be None"
|
||||
# override parent logic to perform partial mask for embedding
|
||||
num_chunks = mesh.size(mesh_dim)
|
||||
# get local shard size and offset on the embedding_dim
|
||||
assert self.offset_shape is not None, (
|
||||
"offset_shape needs to be set for _MaskPartial"
|
||||
)
|
||||
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
|
||||
self.offset_shape[self.offset_dim],
|
||||
num_chunks,
|
||||
my_coordinate[mesh_dim],
|
||||
)
|
||||
mask, masked_tensor = _MaskPartial._mask_tensor(
|
||||
tensor, local_offset_on_dim, local_shard_size
|
||||
)
|
||||
# materialize the mask buffer to be used for reduction
|
||||
self.mask_buffer.materialize_mask(mask)
|
||||
return masked_tensor
|
||||
|
||||
def _reduce_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# perform sum reduction
|
||||
return funcol.all_reduce(
|
||||
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
|
||||
)
|
||||
|
||||
def _reduce_shard_value(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
mesh: DeviceMesh,
|
||||
mesh_dim: int,
|
||||
shard_spec: Placement,
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# call reduce_shard_tensor of the shard_spec.
|
||||
shard_spec = cast(Shard, shard_spec)
|
||||
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, _MaskPartial):
|
||||
return False
|
||||
|
||||
# if either data is not None, we invalidate the sharding cache, as this indicates
|
||||
# the current MaskPartial placement is still in use and should not be used for cache hit.
|
||||
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
|
||||
return False
|
||||
|
||||
return (
|
||||
self.reduce_op == other.reduce_op
|
||||
and self.offset_shape == other.offset_shape
|
||||
and self.offset_dim == other.offset_dim
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return 1 + hash(
|
||||
(
|
||||
self.reduce_op,
|
||||
self.offset_shape,
|
||||
self.offset_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
machine readable representation of the MaskPartial placement
|
||||
"""
|
||||
return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
human readable representation of the MaskPartial placement
|
||||
"""
|
||||
return "MaskP"
|
||||
|
||||
|
||||
@register_op_strategy(aten.embedding.default)
|
||||
def embedding_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
"""
|
||||
@ -239,7 +51,7 @@ def embedding_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
single_mesh_dim_strategies.append(colwise_sharding)
|
||||
|
||||
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
|
||||
embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
|
||||
embedding_partial_placement = MaskPartial(offset_shape=weight_shape, offset_dim=0)
|
||||
|
||||
# NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
|
||||
# from the input indices and use it for output reduction
|
||||
|
||||
torch/distributed/tensor/_ops/_mask_buffer.py (new file, 44 lines)
@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:
    data: Optional[torch.Tensor] = None
    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
    refcount: int = 0

    def materialize_mask(self, mask):
        if self.refcount == 0:
            self.data = mask
        else:
            assert self.data is not None
            if not torch.equal(self.data, mask):
                raise RuntimeError(
                    "MaskBuffer has been materialized with conflicting data"
                )
        self.refcount += 1

    def release_mask(self):
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None

    def apply_mask(self, tensor):
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")

        # NOTE: MaskPartial is being used by the embedding op and the gather op.
        # For gather, the mask has the same dimension as the output tensor, whereas
        # the output of the embedding op has an additional dimension compared to the input,
        # hence the output masking logic below having two different cases.
        if tensor.ndim == self.data.ndim:
            tensor[self.data] = 0.0
        else:
            tensor[self.data, :] = 0.0
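
A minimal usage sketch of the refcounted lifecycle MaskBuffer implements (illustrative only; the mask and tensor values below are made up, and it assumes the new module is importable):

import torch
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer

buf = MaskBuffer()
mask = torch.tensor([True, False, True])

# Every user sharing the buffer must materialize an identical mask;
# the refcount tracks how many users still hold it.
buf.materialize_mask(mask)
buf.materialize_mask(mask.clone())  # ok: equal data, refcount is now 2

out = torch.ones(3)
buf.apply_mask(out)  # zeroes the masked positions of `out` in place

buf.release_mask()
buf.release_mask()  # refcount reaches 0 and the stored mask is dropped
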
@ -17,7 +17,7 @@ from torch.distributed.tensor._op_schema import (
|
||||
TupleStrategy,
|
||||
)
|
||||
from torch.distributed.tensor._ops._common_rules import pointwise_rule
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
|
||||
from torch.distributed.tensor._ops.utils import (
|
||||
expand_to_full_mesh_op_strategy,
|
||||
generate_redistribute_costs,
|
||||
@ -646,7 +646,7 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType:
|
||||
# this only works when the input is sharded on the gather dimension, and
|
||||
# index has size 1 on the gather dimension
|
||||
if dim < len(index_shape) and index_shape[dim] == 1:
|
||||
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
|
||||
index_partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=dim)
|
||||
input_sharding: PlacementList = [
|
||||
index_partial_placement,
|
||||
Shard(dim),
|
||||
|
||||
@ -11,7 +11,7 @@ from torch import Tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor import DTensor, Replicate, Shard
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
from torch.distributed.tensor._ops._embedding_ops import MaskPartial
|
||||
from torch.distributed.tensor._ops._math_ops import (
|
||||
_skip_dim,
|
||||
Reduction,
|
||||
@ -236,7 +236,7 @@ def _nll_loss_forward(
|
||||
|
||||
# The following code block is a distributed version of
|
||||
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
|
||||
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
safe_target_partial_ = partial_placement._partition_value(
|
||||
safe_target_, mesh, mesh_dim
|
||||
)
|
||||
@ -375,7 +375,7 @@ def _nll_loss_and_log_softmax_backward(
|
||||
|
||||
# The following code block is a distributed version of
|
||||
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
|
||||
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
|
||||
safe_target = safe_target.squeeze(channel_dim).flatten()
|
||||
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
|
||||
# only update grad_input to -1 if not masked
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
@ -17,9 +18,10 @@ from torch.distributed.tensor._collective_utils import (
|
||||
shard_dim_alltoall,
|
||||
unpad_tensor,
|
||||
)
|
||||
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer
|
||||
|
||||
|
||||
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
|
||||
__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"]
|
||||
|
||||
|
||||
# Appease TestPublicBindings.test_correct_module_names
|
||||
@ -841,3 +843,149 @@ class Partial(torch._C._distributed.Partial):
|
||||
|
||||
# We keep the old _Partial name for a while for BC reason
|
||||
_Partial = Partial
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MaskPartial(Partial):
|
||||
"""
|
||||
A partial mask placement devised for rowwise sharded embedding op, where we need
|
||||
to mask and adjust the indices to the local embedding shard, embedding masking
|
||||
is a special type of the Partial placement
|
||||
|
||||
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
|
||||
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
|
||||
"""
|
||||
|
||||
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
|
||||
|
||||
# required fields for computing the local offset and deriving the mask
|
||||
offset_shape: Optional[torch.Size] = None
|
||||
offset_dim: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_op=None,
|
||||
mask_buffer=None,
|
||||
offset_shape=None,
|
||||
offset_dim=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(reduce_op)
|
||||
if mask_buffer is None:
|
||||
mask_buffer = MaskBuffer()
|
||||
object.__setattr__(self, "mask_buffer", mask_buffer)
|
||||
object.__setattr__(self, "offset_shape", offset_shape)
|
||||
object.__setattr__(self, "offset_dim", offset_dim)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _mask_tensor(
|
||||
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Build the input mask and save it for the current partial placement
|
||||
# this is so that the output of embedding op can reuse the same partial
|
||||
# placement saved mask to perform mask + reduction
|
||||
mask = (tensor < local_offset_on_dim) | (
|
||||
tensor >= local_offset_on_dim + local_shard_size
|
||||
)
|
||||
# mask the input tensor
|
||||
masked_tensor = tensor.clone() - local_offset_on_dim
|
||||
masked_tensor[mask] = 0
|
||||
return mask, masked_tensor
|
||||
|
||||
def _partition_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
my_coordinate = mesh.get_coordinate()
|
||||
assert my_coordinate is not None, "my_coordinate should not be None"
|
||||
# override parent logic to perform partial mask for embedding
|
||||
num_chunks = mesh.size(mesh_dim)
|
||||
# get local shard size and offset on the embedding_dim
|
||||
assert self.offset_shape is not None, (
|
||||
"offset_shape needs to be set for MaskPartial"
|
||||
)
|
||||
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
|
||||
self.offset_shape[self.offset_dim],
|
||||
num_chunks,
|
||||
my_coordinate[mesh_dim],
|
||||
)
|
||||
mask, masked_tensor = MaskPartial._mask_tensor(
|
||||
tensor, local_offset_on_dim, local_shard_size
|
||||
)
|
||||
# materialize the mask buffer to be used for reduction
|
||||
self.mask_buffer.materialize_mask(mask)
|
||||
return masked_tensor
|
||||
|
||||
def _reduce_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# perform sum reduction
|
||||
return funcol.all_reduce(
|
||||
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
|
||||
)
|
||||
|
||||
def _reduce_shard_value(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
mesh: DeviceMesh,
|
||||
mesh_dim: int,
|
||||
shard_spec: Placement,
|
||||
) -> torch.Tensor:
|
||||
# by the time we need reduction, we should have already saved the mask
|
||||
assert self.mask_buffer.data is not None
|
||||
|
||||
# apply the mask to the tensor that is pending reduction
|
||||
self.mask_buffer.apply_mask(tensor)
|
||||
|
||||
# clear the mask buffer
|
||||
self.mask_buffer.release_mask()
|
||||
|
||||
# call reduce_shard_tensor of the shard_spec.
|
||||
shard_spec = cast(Shard, shard_spec)
|
||||
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, MaskPartial):
|
||||
return False
|
||||
|
||||
# if either data is not None, we invalidate the sharding cache, as this indicates
|
||||
# the current MaskPartial placement is still in use and should not be used for cache hit.
|
||||
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
|
||||
return False
|
||||
|
||||
return (
|
||||
self.reduce_op == other.reduce_op
|
||||
and self.offset_shape == other.offset_shape
|
||||
and self.offset_dim == other.offset_dim
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return 1 + hash(
|
||||
(
|
||||
self.reduce_op,
|
||||
self.offset_shape,
|
||||
self.offset_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
machine readable representation of the MaskPartial placement
|
||||
"""
|
||||
return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
human readable representation of the MaskPartial placement
|
||||
"""
|
||||
return "MaskP"
|
||||
|
||||
@ -206,6 +206,21 @@ class TracerBase:
|
||||
if current_meta.get("in_grad_fn", 0) > 0:
|
||||
annotation_log.debug("seq_nr from current_meta")
|
||||
new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
|
||||
|
||||
# See Note [Functionalization View Replay Annotation]
|
||||
# Overriding some node meta with the original node meta of the
|
||||
# regenerated node.
|
||||
replay_node: Node = fx_traceback.get_current_replay_node()
|
||||
if replay_node is not None:
|
||||
node.meta["is_functional_regenerated"] = True
|
||||
if "seq_nr" in replay_node.meta:
|
||||
annotation_log.debug("seq_nr from replay_node")
|
||||
new_seq_nr = replay_node.meta["seq_nr"]
|
||||
if "custom" in replay_node.meta:
|
||||
node.meta["custom"] = replay_node.meta.get("custom")
|
||||
if "stack_trace" in replay_node.meta:
|
||||
node.stack_trace = replay_node.meta.get("stack_trace")
|
||||
|
||||
annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
|
||||
node.meta["seq_nr"] = new_seq_nr
|
||||
|
||||
|
||||
@ -30,9 +30,12 @@ __all__ = [
|
||||
"NodeSource",
|
||||
"NodeSourceAction",
|
||||
"get_graph_provenance_json",
|
||||
"set_current_replay_node",
|
||||
"get_current_replay_node",
|
||||
]
|
||||
|
||||
current_meta: dict[str, Any] = {}
|
||||
current_replay_node: Optional[Node] = None
|
||||
should_preserve_node_meta = False
|
||||
|
||||
|
||||
@ -400,6 +403,31 @@ def get_current_meta() -> dict[str, Any]:
|
||||
return current_meta
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
def set_current_replay_node(node):
|
||||
"""
|
||||
Set the currently replay node. If `current_replay_node` is not None,
|
||||
then we're re-generating the `current_replay_node` in FunctionalTensorMode.
|
||||
"""
|
||||
# See [Note] annotation for more details.
|
||||
global current_replay_node
|
||||
saved_current_replay_node = current_replay_node
|
||||
try:
|
||||
current_replay_node = node
|
||||
yield
|
||||
finally:
|
||||
current_replay_node = saved_current_replay_node
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def get_current_replay_node():
|
||||
"""
|
||||
Get the currently replay node
|
||||
"""
|
||||
return current_replay_node
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -14,14 +14,11 @@ from torch.backends.cuda import (
    SDPAParams,
)

from .varlen import varlen_attn


__all__: list[str] = [
    "SDPBackend",
    "sdpa_kernel",
    "WARN_FOR_UNFUSED_KERNELS",
    "varlen_attn",
]

# Note: [SDPA warnings]
@ -34,7 +34,7 @@ from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
)
|
||||
from torch.nn.attention._utils import _validate_sdpa_input
|
||||
from torch.utils._pytree import GetAttrKey, register_pytree_node, tree_map_only
|
||||
from torch.utils._pytree import GetAttrKey, tree_map_only
|
||||
|
||||
|
||||
# Private debug flag to disable internal compilation wrapping for debugging purposes.
|
||||
@ -1648,12 +1648,3 @@ def flex_attention(
|
||||
return _finalize_outputs(
|
||||
out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
|
||||
)
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
BlockMask,
|
||||
BlockMask._flatten,
|
||||
BlockMask._unflatten,
|
||||
flatten_with_keys_fn=BlockMask._flatten_with_keys,
|
||||
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
|
||||
)
|
||||
|
||||
@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import NamedTuple, Optional, Union
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -33,8 +33,7 @@ class AuxRequest(NamedTuple):
|
||||
lse: bool = False
|
||||
|
||||
|
||||
# import failures when I try to register as custom op
|
||||
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
|
||||
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
|
||||
def _varlen_attn(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@ -44,7 +43,7 @@ def _varlen_attn(
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
is_causal: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Private custom op for variable-length attention.
|
||||
|
||||
@ -70,7 +69,7 @@ def _varlen_attn(
|
||||
False, # return_debug_mask
|
||||
)
|
||||
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
|
||||
output, softmax_lse = result[0], result[1]
|
||||
output, softmax_lse, rng_state = result[0], result[1], result[6]
|
||||
else:
|
||||
log.info("Using Flash Attention backend for varlen_attn")
|
||||
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
|
||||
@ -86,10 +85,13 @@ def _varlen_attn(
|
||||
return_debug_mask=False,
|
||||
)
|
||||
|
||||
return output, softmax_lse
|
||||
rng_state_ = torch.zeros(
|
||||
(2,), dtype=torch.uint64, device=query.device
|
||||
) # hardcoded since dropout is hardcoded to 0
|
||||
return output, softmax_lse, rng_state_
|
||||
|
||||
|
||||
# @_varlen_attn.register_fake
|
||||
@_varlen_attn.register_fake
|
||||
def _varlen_attn_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@ -99,7 +101,7 @@ def _varlen_attn_fake(
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
is_causal: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Fake implementation for meta tensor computation and tracing.
|
||||
|
||||
@ -117,7 +119,9 @@ def _varlen_attn_fake(
|
||||
(num_heads, total_q), dtype=torch.float, device=query.device
|
||||
)
|
||||
|
||||
return output, logsumexp
|
||||
rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
|
||||
|
||||
return output, logsumexp, rng_state
|
||||
|
||||
|
||||
def varlen_attn(
|
||||
@ -191,9 +195,145 @@ def varlen_attn(
|
||||
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
|
||||
... )
|
||||
"""
|
||||
out, lse = _varlen_attn(
|
||||
out, lse, _ = torch.ops.torch_attn._varlen_attn(
|
||||
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
|
||||
)
|
||||
if return_aux is not None and return_aux.lse:
|
||||
return out, lse
|
||||
return out
|
||||
|
||||
|
||||
def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
|
||||
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
|
||||
out, lse, rng_state = output
|
||||
ctx.query = query
|
||||
ctx.key = key
|
||||
ctx.value = value
|
||||
ctx.cu_seq_q = cu_seq_q
|
||||
ctx.cu_seq_k = cu_seq_k
|
||||
ctx.max_q = max_q
|
||||
ctx.max_k = max_k
|
||||
ctx.is_causal = is_causal
|
||||
ctx.output = out
|
||||
ctx.lse = lse
|
||||
ctx.rng_state = rng_state
|
||||
|
||||
|
||||
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
|
||||
def _varlen_attn_backward(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor,
|
||||
cu_seq_k: torch.Tensor,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
is_causal: bool,
|
||||
rng_state: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
unused = torch.empty(0, device=query.device)
|
||||
|
||||
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
|
||||
if use_cudnn:
|
||||
log.info("Using cuDNN backend for varlen_attn")
|
||||
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0,
|
||||
is_causal,
|
||||
rng_state,
|
||||
unused,
|
||||
)
|
||||
else:
|
||||
log.info("Using Flash Attention backend for varlen_attn")
|
||||
dq, dk, dv = torch.ops.aten._flash_attention_backward(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0,
|
||||
is_causal,
|
||||
rng_state,
|
||||
unused,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
@_varlen_attn_backward.register_fake
|
||||
def _varlen_attn_backward_fake(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor,
|
||||
cu_seq_k: torch.Tensor,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
is_causal: bool,
|
||||
rng_state: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Fake implementation for meta tensor computation and tracing.
|
||||
"""
|
||||
|
||||
grad_query = torch.empty_like(query)
|
||||
grad_key = torch.empty_like(key)
|
||||
grad_value = torch.empty_like(value)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _backward(
|
||||
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
|
||||
) -> tuple[Optional[torch.Tensor], ...]:
|
||||
query = ctx.query
|
||||
key = ctx.key
|
||||
value = ctx.value
|
||||
cu_seq_q = ctx.cu_seq_q
|
||||
cu_seq_k = ctx.cu_seq_k
|
||||
max_q = ctx.max_q
|
||||
max_k = ctx.max_k
|
||||
is_causal = ctx.is_causal
|
||||
out = ctx.output
|
||||
lse = ctx.lse
|
||||
rng_state = ctx.rng_state
|
||||
|
||||
# rng_state = torch.empty(2, device=query.device)
|
||||
|
||||
dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
is_causal,
|
||||
rng_state,
|
||||
)
|
||||
return dq, dk, dv, None, None, None, None, None, None
|
||||
|
||||
|
||||
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
|
||||
|
||||
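
A minimal end-to-end sketch of the public varlen_attn entry point, assuming q/k/v are packed as (total_tokens, num_heads, head_dim) and a CUDA build with a supported Flash/cuDNN backend:

import torch
from torch.nn.attention import varlen_attn

num_heads, head_dim = 8, 64
seq_lens = [5, 3]  # two sequences of different lengths packed into one batch
total_tokens = sum(seq_lens)

# cumulative sequence offsets for the packed layout: [0, 5, 8]
cu_seq = torch.tensor([0, 5, 8], device="cuda", dtype=torch.int32)

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# same offsets for queries and keys; max_q/max_k are the longest sequence length
out = varlen_attn(q, k, v, cu_seq, cu_seq, max(seq_lens), max(seq_lens), is_causal=True)
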
@ -43,7 +43,6 @@ from torch.testing._internal.common_utils import (
|
||||
_TestParametrizer,
|
||||
skipIfMPS,
|
||||
skipIfTorchDynamo,
|
||||
skipIfXpu,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
)
|
||||
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
|
||||
@ -2201,9 +2200,6 @@ optim_db: list[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
device_type="mps",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
|
||||
"TestOptimRenewed",
|
||||
|
||||
@ -447,6 +447,56 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
def ceil_div(a, b):
    return (a + b - 1) // b

# NVIDIA Blackwell HW requires scales for MX/NV blocked formats to be in a 128x4 tile layout,
# with a weird 32x4x4 internal layout of that tile. If we want to take swizzled scales and use them
# for non-gemm purposes (like testing), we need to de-swizzle them, then they can be applied much
# more naturally.
def from_blocked(input, input_scales, blocksize) -> torch.Tensor:
    # Matrix is in a 128x4 pattern, internally blocked as 32x4x4 nonsense.
    # Output should be [input.size(0), input.size(1) // blocksize] scales
    output_scales = torch.zeros(
        (input.size(0), input.size(1) // blocksize),
        device=input.device,
        dtype=input_scales.dtype,
    )

    # Swizzled scales are padded to tiles of 128x4, we need to replicate how that padding
    # happened for offset purposes.
    # There are K//blocksize scales, padded to groups of 4.
    num_col_tiles = ceil_div(ceil_div(input.size(1), blocksize), 4)

    # (Very) slow reference implementation using horrifying loops.
    for i in range(input.size(0)):
        for j in range(input.size(1) // blocksize):
            # which 128x4 tile of scaling factors am I in
            scale_tile_h = i // 128
            scale_tile_w = j // 4

            # There are (padded) input_scales.size(1) // 4 tiles along the w dim.
            # So offset is 512 * (h_tile * tiles_per_row + tile_in_row)
            tile_offset = 512 * (scale_tile_h * num_col_tiles + scale_tile_w)

            # indices within the tile - use nomenclature directly from cublas docs
            outer = i % 128  # "outer" in cublas docs
            inner = j % 4  # "inner" in cublas docs

            # Note: "offset" is given in terms of bytes in the cublas docs, but our scales are
            # e8m0 anyway, so 1B == 1 value => use offset directly.
            # Formula directly from cublas docs in 3.1.4.3.2
            offset = tile_offset + (outer % 32) * 16 + (outer // 32) * 4 + inner

            output_scales[i, j] = input_scales[offset]

    return output_scales
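# Worked example of the offset formula above (not part of the diff; values chosen purely for
# illustration). Assume a [256, 256] input with blocksize = 32, i.e. 8 scale columns per row:
#   num_col_tiles = ceil_div(ceil_div(256, 32), 4) = ceil_div(8, 4) = 2
# For row i = 130 and scale column j = 5:
#   scale_tile_h = 130 // 128 = 1
#   scale_tile_w = 5 // 4     = 1
#   tile_offset  = 512 * (1 * 2 + 1) = 1536
#   outer = 130 % 128 = 2, inner = 5 % 4 = 1
#   offset = 1536 + (2 % 32) * 16 + (2 // 32) * 4 + 1 = 1536 + 32 + 0 + 1 = 1569
# so output_scales[130, 5] = input_scales[1569].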

def from_blocked_format(x_mxfp8, scales_unswizzled, blocksize=32):
    # expand scales
    scales = torch.repeat_interleave(scales_unswizzled, blocksize, dim=1)

    # de-scale and convert
    x_f32 = x_mxfp8.to(torch.float) * scales.to(torch.float)
    return x_f32.to(torch.bfloat16)
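# Illustrative usage (not part of the diff): together with the swizzled scales produced by
# to_blocked() below, these helpers recover a bf16 reference tensor for test comparisons.
# `x_mxfp8` and `swizzled_scales` are placeholders for whatever the test constructs:
#   unswizzled = from_blocked(x_mxfp8, swizzled_scales.flatten(), 32)
#   x_ref_bf16 = from_blocked_format(x_mxfp8, unswizzled, blocksize=32)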

def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

@ -950,6 +950,13 @@ def prof_meth_call(*args, **kwargs):
torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]

def _get_test_report_path():
    # allow users to override the test file location. We need this
    # because the distributed tests run the same test file multiple
    # times with different configurations.
    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
    test_source = override if override is not None else 'python-unittest'
    return os.path.join('test-reports', test_source)
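# Illustrative behavior (not part of the diff): the directory depends only on the
# TEST_REPORT_SOURCE_OVERRIDE environment variable, e.g.
#   unset                                  -> 'test-reports/python-unittest'
#   TEST_REPORT_SOURCE_OVERRIDE=dist-gloo  -> 'test-reports/dist-gloo'
# ('dist-gloo' is just an example value; distributed runners pick their own source name.)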

def parse_cmd_line_args():
    global CI_FUNCTORCH_ROOT
@ -980,7 +987,9 @@ def parse_cmd_line_args():
    parser.add_argument('--repeat', type=int, default=1)
    parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
    parser.add_argument('--use-pytest', action='store_true')
    parser.add_argument('--save-xml', type=str)
    parser.add_argument('--save-xml', nargs='?', type=str,
                        const=_get_test_report_path(),
                        default=_get_test_report_path() if IS_CI else None)
    parser.add_argument('--discover-tests', action='store_true')
    parser.add_argument('--log-suffix', type=str, default="")
    parser.add_argument('--run-parallel', type=int, default=1)
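# Illustrative aside (not part of the diff): with nargs='?', the --save-xml flag now has the
# standard three argparse behaviors:
#   (flag omitted)        -> default: _get_test_report_path() on CI, otherwise None
#   --save-xml            -> const:   _get_test_report_path()
#   --save-xml some/dir   -> 'some/dir'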
@ -1010,9 +1019,6 @@ def parse_cmd_line_args():
    # infer flags based on the default settings
    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()

    if args.save_xml is None and IS_CI:
        args.xml_dir = get_report_dir(sys.argv[0], args.log_suffix, args.use_pytest)

    RERUN_DISABLED_TESTS = args.rerun_disabled_tests

    SLOW_TESTS_FILE = args.import_slow_tests
@ -1185,37 +1191,19 @@ def lint_test_case_extension(suite):
    return succeed


def get_report_dir(test_name: str, log_suffix: Optional[str], is_pytest: bool) -> str:
    """Generates a test report directory path. Test name does not need to be
    sanitized."""
    # total path = test-reports/test_source+log_suffix/test_filename
    # Base path
    test_source = "python-unittest"
    if is_pytest:
        test_source = "python-pytest"
    # allow users to override the test file location. We need this
    # because the distributed tests run the same test file multiple
    # times with different configurations.
    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
    if override is not None:
        test_source = override

    # Add log suffix if provided
    if log_suffix and log_suffix != "":
        test_source = test_source + log_suffix

    test_report_dir = os.path.join('test-reports', test_source)

    # Add test file name to path
    test_filename = sanitize_test_filename(test_name)
    test_report_dir = os.path.join(test_report_dir, test_filename)

    os.makedirs(test_report_dir, exist_ok=True)
    return test_report_dir


def get_report_path(report_dir: str, test_filename: str) -> str:
    return os.path.join(report_dir, f"{sanitize_test_filename(test_filename)}-{os.urandom(8).hex()}.xml")
def get_report_path(argv=None, pytest=False):
    if argv is None:
        argv = UNITTEST_ARGS
    test_filename = sanitize_test_filename(argv[0])
    test_report_path = TEST_SAVE_XML + LOG_SUFFIX
    test_report_path = os.path.join(test_report_path, test_filename)
    if pytest:
        test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
        os.makedirs(test_report_path, exist_ok=True)
        test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
        return test_report_path
    os.makedirs(test_report_path, exist_ok=True)
    return test_report_path
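# Illustrative aside (not part of the diff): with the new split API the caller builds the
# directory once and then derives per-run XML paths from it. Assuming sanitize_test_filename
# strips the directory and extension:
#   report_dir = get_report_dir("test/test_ops.py", "-dynamo", is_pytest=False)
#     -> 'test-reports/python-unittest-dynamo/test_ops'
#   get_report_path(report_dir, "test/test_ops.py")
#     -> 'test-reports/python-unittest-dynamo/test_ops/test_ops-<16 hex chars>.xml'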


def sanitize_pytest_xml(xml_file: str):
@ -1358,7 +1346,7 @@ def run_tests(argv=None):
        pytest_args = argv + ["--use-main-module"]
        test_report_path = ""
        if TEST_SAVE_XML:
            test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
            test_report_path = get_report_path(pytest=True)
            print(f'Test results will be stored in {test_report_path}')
            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
        if PYTEST_SINGLE_TEST:
@ -1402,7 +1390,7 @@ def run_tests(argv=None):
            def printErrors(self) -> None:
                super().printErrors()
                self.printErrorList("XPASS", self.unexpectedSuccesses)
        test_report_path = get_report_path(TEST_SAVE_XML, argv[0])
        test_report_path = get_report_path()
        verbose = '--verbose' in argv or '-v' in argv
        if verbose:
            print(f'Test results will be stored in {test_report_path}')

@ -7,7 +7,8 @@ import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, cast, TypeVar
from typing import Any, cast, overload, TypeVar
from typing_extensions import Self


# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
@ -158,7 +159,12 @@ class _DecoratorContextManager:
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses."""

    def __new__(cls, orig_func=None):
    @overload
    def __new__(cls, orig_func: F) -> F: ...  # type: ignore[misc]
    @overload
    def __new__(cls, orig_func: None = None) -> Self: ...

    def __new__(cls, orig_func: F | None = None) -> Self | F:  # type: ignore[misc]
        if orig_func is None:
            return super().__new__(cls)
        return cls()(orig_func)
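# Illustrative aside (not part of the diff): subclasses such as torch.no_grad can be applied
# either bare or called, which is exactly what the __new__ overloads above now type-check.
import torch

@torch.no_grad  # bare: __new__ gets the function and returns the wrapped function directly
def f(x):
    return x * 2

@torch.no_grad()  # called: __new__ returns a fresh context-manager instance used as a decorator
def g(x):
    return x * 2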

@ -311,7 +311,11 @@ def escape(n):


def is_cuda_tensor(obj):
    return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
    return (
        isinstance(obj, torch.Tensor) and
        obj.device.type == "cuda" and
        not isinstance(obj, torch._subclasses.FakeTensor)
    )
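# Illustrative aside (not part of the diff, assumes a CUDA build of PyTorch): a FakeTensor can
# report device.type == "cuda" without owning any CUDA memory, so the memory-viz tooling must
# not count it as a real allocation - hence the explicit FakeTensor exclusion above.
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.empty(8, device="cuda")  # no CUDA memory is actually allocated
# t.device.type == "cuda", but is_cuda_tensor(t) is False because of the isinstance check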


def cuda_allocation_context():
    snapshot = torch.cuda.memory._snapshot()