Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"
This reverts commit caba37e99b03d2199848197de4e452b78c8c2a23. https://github.com/pytorch/pytorch/pull/131863 was reverted on behalf of https://github.com/izaitsevfb because it breaks the executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner) ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2291855634))
@@ -452,15 +452,9 @@ void cpu_flash_attention(
            dst_data,
            headSize);
      }

      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        // Row sums for full masked out rows are 0, we set them to 1
        // in order to avoid NaNs in the output and instead set fully
        // masked out rows to 0
        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },

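The guard in this hunk keeps fully masked rows from producing NaN in the final normalization: a row whose softmax sum is 0 would otherwise be divided by 0. A minimal Python sketch of the same idea (normalize_rows, acc and the argument names are illustrative, not the kernel's):

    import torch

    def normalize_rows(acc, qk_sum, qk_max):
        # acc: (rows, head_size) accumulated attention-weighted values;
        # qk_sum / qk_max: per-row softmax sum and running max.
        qk_max = torch.where(torch.isneginf(qk_max), torch.zeros_like(qk_max), qk_max)  # -inf max -> 0
        qk_sum = torch.where(qk_sum == 0, torch.ones_like(qk_sum), qk_sum)              # avoid 0/0 -> NaN
        return acc / qk_sum.unsqueeze(-1)   # fully masked rows stay all-zero

    acc = torch.zeros(2, 4)  # a fully masked row accumulates nothing
    print(normalize_rows(acc, torch.tensor([0.0, 0.0]), torch.tensor([-float("inf"), -float("inf")])))
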
@@ -8890,7 +8890,6 @@
  variants: method, function
  dispatch:
    QuantizedCPU: eq_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested
  tags: [core, pointwise]

- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

@@ -322,14 +322,5 @@ Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
      });
}

Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
  return NestedTensor_elementwise_Tensor(
      self, other, "eq", false /*supports_striding*/,
      [](const Tensor& b1, const Tensor& b2) {
        return b1.eq(b2);
      });
}

} // namespace native
} // namespace at

@@ -647,11 +647,9 @@ Tensor _safe_softmax(
    int64_t dim,
    std::optional<ScalarType> dtype) {
  auto out = at::softmax(self, dim, dtype);
  const auto neg_inf = at::scalar_tensor(-std::numeric_limits<float>::infinity(), at::TensorOptions().dtype(out.dtype()).device(out.device()));
  const auto masked = self.eq(neg_inf);
  const auto masked = self.eq(-std::numeric_limits<float>::infinity());
  const auto masked_rows = all(masked, dim, true);
  const auto zero = at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device()));
  return at::where(masked_rows, zero, out);
  return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability

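Both variants of _safe_softmax in this hunk implement the same semantics: compute a regular softmax, then zero out any row whose inputs are entirely -inf (a fully masked row), which plain softmax would turn into NaN. A hedged Python sketch of that behavior (not the dispatcher implementation itself):

    import torch

    def safe_softmax_sketch(x: torch.Tensor, dim: int) -> torch.Tensor:
        out = torch.softmax(x, dim=dim)  # produces NaN on all -inf rows
        fully_masked = x.eq(float("-inf")).all(dim=dim, keepdim=True)
        return torch.where(fully_masked, torch.zeros_like(out), out)

    scores = torch.tensor([[0.0, 1.0], [float("-inf"), float("-inf")]])
    print(torch.softmax(scores, -1))        # second row is NaN
    print(safe_softmax_sketch(scores, -1))  # second row is 0
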
@@ -839,7 +837,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
        attn.add_(*attn_mask);
      }
    }
    attn = at::_safe_softmax(attn, -1);
    attn = at::softmax(attn, -1);
    if (dropout_p > 0.0) {
      if (dropout_mask.has_value()) {
        // In order to validate the correctness of the fused kernels, we need to

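For context, the math reference path this hunk belongs to follows the usual composition: scale the query/key product, add the optional mask, apply softmax (safe or plain, depending on which side of the revert you are on), optionally apply dropout, then multiply by the values. A rough Python sketch of that reference computation, with the hypothetical name sdpa_math_sketch and a float additive mask assumed (causal handling and nested tensors omitted):

    import math
    import torch

    def sdpa_math_sketch(q, k, v, attn_mask=None, dropout_p=0.0):
        scale = 1.0 / math.sqrt(q.size(-1))
        attn = q @ k.transpose(-2, -1) * scale
        if attn_mask is not None:
            attn = attn + attn_mask                # additive float mask
        attn = torch.softmax(attn, dim=-1)         # the hunk swaps this for _safe_softmax
        if dropout_p > 0.0:
            attn = torch.nn.functional.dropout(attn, p=dropout_p)
        return attn @ v

    q = k = v = torch.randn(1, 2, 4, 8)
    print(sdpa_math_sketch(q, k, v).shape)         # torch.Size([1, 2, 4, 8])
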
@@ -144,10 +144,7 @@ class MemoryEfficientAttentionNormalize {
    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    // Row sums for full masked out rows are 0, we set them to 1
    // in order to avoid NaNs in the output and instead set them to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;
    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
    ElementCompute beta = alpha * m_prime_[row];

    intermediate = mul_add_source(beta, converted_source); // X = beta * C

@@ -177,10 +174,7 @@ class MemoryEfficientAttentionNormalize {
    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    // Row sums for full masked out rows are 0, we set them to 1
    // in order to avoid NaNs in the output and instead set them to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;
    ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;

    intermediate = mul_accumulator(
        alpha, converted_accumulator); // X = alpha * C + uniform

@@ -1166,10 +1166,6 @@ struct AttentionKernel {
    auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
    constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
    if (thread_id() < p.num_queries) {
      // We set fully masked out rows to 0, the sumexp for masked out rows will be 0
      // We update it to be 1 prior to calling log so that log(1) = 0
      s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1 : s_prime[thread_id()];
      mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0 : mi[thread_id()];
      p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
          cutlass::fast_log(accum_t(s_prime[thread_id()]));
    } else if (thread_id() < lse_dim) {

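The logsumexp write-out above reconstructs log(sum(exp(scores))) for each query row from two running statistics: mi, the row maximum kept in log2 units (hence the division by kLog2e), and s_prime, the sum of exponentials shifted by that maximum. A small Python check of the same identity (this mirrors the bookkeeping, not the CUTLASS kernel itself; names are illustrative):

    import math
    import torch

    def lse_from_running_stats(scores: torch.Tensor) -> torch.Tensor:
        # scores: (rows, cols); keep the row max in log2 space like the kernel does.
        klog2e = math.log2(math.e)
        mi = (scores * klog2e).max(dim=-1).values                     # running max, log2 space
        s_prime = torch.exp(scores - (mi / klog2e)[:, None]).sum(dim=-1)
        return mi / klog2e + torch.log(s_prime)                       # natural-log logsumexp

    scores = torch.randn(3, 5)
    print(torch.allclose(lse_from_running_stats(scores), torch.logsumexp(scores, dim=-1)))  # True
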
@@ -1791,6 +1791,9 @@ class TestOperators(TestCase):
            ), # NYI: forward-AD for soft_margin_loss_backward
            xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
            xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
            xfail(
                "torch.ops.aten._safe_softmax.default"
            ), # NYI: forward-AD for _safe_softmax
            skip("nn.functional.scaled_dot_product_attention"),
            xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
            xfail(

@@ -1973,6 +1976,9 @@ class TestOperators(TestCase):
            xfail(
                "nn.functional.ctc_loss"
            ), # ForwardAD not implemented and no decomposition
            xfail(
                "torch.ops.aten._safe_softmax.default"
            ), # ForwardAD not implemented
            xfail("nn.functional.dropout2d"), # calls random op
            xfail("nn.functional.dropout3d"), # calls random op
            xfail("nn.functional.dropout"), # calls random op

@@ -12385,21 +12385,13 @@ if __name__ == '__main__':
        result = model(encoder_input, src_key_padding_mask=mask)
        self.assertEqual(result.shape, ref_output.shape)
        torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
        # 1 values are masked. Since there is only 1 input embedding this
        # will result in nan.
        mask = torch.tensor([[1]], device=device) == 1
        result = model(encoder_input, src_key_padding_mask=mask)
        fast_path_device = result.is_cuda or result.is_cpu
        result = result.cpu().detach().numpy()
        # Non Fast Paths
        if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
            # We changed the semantics on the non fast path so that fully masked out rows return
            # 0 from attention, thus NaNs should no longer be present and the output should be nonzero
            # due to skip connections
            self.assertTrue(not np.isnan(result).any())
        else:
            # Fast Paths
            self.assertTrue(np.isnan(result).all())

        # deterministic input
        encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
                                              [[5., 6., 7., 8.]]], device=device, dtype=dtype))

@@ -347,7 +347,6 @@ class TestTransformers(NNTestCase):
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
        # MHA converts all
        with torch.no_grad():
            B = 2
            L = 4

@@ -357,7 +356,7 @@ class TestTransformers(NNTestCase):
            if attn_mask_dim == 2:
                attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim == 3:
                attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L)
                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim is None:
                attn_mask = None

@@ -373,9 +372,7 @@ class TestTransformers(NNTestCase):
            out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            mha.eval()  # enable fast path
            out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            # The FP kernel will return NaNs, while the sdpa kernel which is run when the fast path is
            # turned off returns 0 instead of NaNs for fully masked rows
            torch.testing.assert_close(out, out_fp.nan_to_num())
            self.assertEqual(out, out_fp)

    @parametrize("nhead", [1, 4, 8])
    def test_transformerencoderlayer_src_mask(self, device, nhead):

@@ -1159,25 +1156,6 @@ class TestTransformers(NNTestCase):
        else:
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask, dropout_p, is_causal)
            # This tests the fully masked out rows case
            if torch.isnan(expected).any():
                row_sums = attn_mask.sum(dim=-1)
                masked_out_rows = (row_sums == 0)

                for _ in range((input_dim - attn_mask_dim) - 1):
                    masked_out_rows = masked_out_rows.unsqueeze(0)

                masked_out_rows = masked_out_rows.expand(expected.shape[:-1])
                # Slice out the fully masked rows from expected and actual
                expected_masked_out = expected[masked_out_rows]
                actual_masked_out = actual[masked_out_rows]

                expected_all_nan = torch.isnan(expected_masked_out).all()
                actual_all_zero = (actual_masked_out.abs().sum() == 0)

                self.assertTrue(expected_all_nan)
                self.assertTrue(actual_all_zero)
                return

        self.assertEqual(actual, expected)

@@ -1983,7 +1961,7 @@ class TestSDPACpuOnly(NNTestCase):
    @parametrize("n_head", [1, 3])
    @parametrize("head_dim", [8])
    @parametrize("mask_dim", [2, 4])
    @parametrize("bool_mask", [False, True])
    @parametrize("bool_mask", [0, 1])
    @parametrize("train", [True, False])
    @parametrize("casual", [True, False])
    @parametrize("set_attn_mask", [True, False])

@@ -2058,9 +2036,6 @@ class TestSDPACpuOnly(NNTestCase):
        if dtype in [torch.bfloat16, torch.float16]:
            math_ref = math_ref.to(dtype)

        self.assertFalse(torch.isnan(math_ref).any())
        self.assertFalse(torch.isnan(actual).any())

        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)

        if train:

@@ -2089,104 +2064,6 @@ class TestSDPACpuOnly(NNTestCase):
        actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        self.assertEqual(math_ref, actual)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION])
    @parametrize("seq_len", [32, 64, 128])
    @parametrize("head_dim", [16, 32])
    @parametrize("dtype", [torch.float32, torch.float16])
    def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype):
        def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4):
            query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
            key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
            value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)

            # Create a mask with deterministic row masking
            mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device)

            # Mask every nth row
            mask[0, 0, ::mask_every_n_rows, :] = False

            # Create a fixed pattern for element-wise masking
            element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
            element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True

            # Combine row masking and element-wise masking
            mask = mask & element_mask.unsqueeze(0).unsqueeze(0)

            return query, key, value, mask

        def compute_output_and_grads(query, key, value, mask, backend):
            with sdpa_kernel(backend):
                masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask)
                loss = masked_out.sum()
                grads = torch.autograd.grad(loss, [query, key, value])
            return masked_out, grads

        if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device):
            unittest.skip("FlashAttention does not support masks on cuda")
            return
        if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device):
            unittest.skip("EfficientAttention does not support masks on cpu")
            return
        query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype)

        # Compute results for the tested backend
        backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend)

        # Compute results for the Math backend
        math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH)

        # Compare outputs
        torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0)
        self.assertFalse(backend_out.isnan().any())
        self.assertFalse(math_out.isnan().any())
        # Compare gradients
        for bg, mg in zip(backend_grads, math_grads):
            torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0)
            self.assertFalse(bg.isnan().any())
            self.assertFalse(mg.isnan().any())

        # Check if masked rows are zero in output
        mask_sum = mask.sum(dim=-1, keepdim=True)
        masked_rows = (mask_sum == 0).expand_as(backend_out)
        self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found")
        assert torch.all(backend_out[masked_rows] == 0), \
            f"Non-zero values in fully masked rows for {backend=}"

        # Check if gradients for masked rows are zero
        grad_query = backend_grads[0]
        assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}"

@parametrize("dtype", [torch.float32, torch.float16])
|
||||
@parametrize("fill_val", [float("inf")])
|
||||
def test_non_masked_rows_nan_props(self, device, dtype, fill_val):
|
||||
query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype)
|
||||
# a single NaN in the query input
|
||||
query[0, 1, 2, 3] = fill_val
|
||||
query = query.detach().requires_grad_(True)
|
||||
key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
|
||||
value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
|
||||
self.assertTrue(torch.isnan(out).any())
|
||||
out.sum().backward()
|
||||
self.assertTrue(torch.isnan(query.grad).any())
|
||||
|
||||
@parametrize("kernel", [SDPBackend.MATH])
|
||||
def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
|
||||
# https://github.com/pytorch/pytorch/issues/105190.
|
||||
def ref(x):
|
||||
v1 = torch.matmul(x, x.transpose(-1, -2))
|
||||
v2 = v1 / -0.0001
|
||||
v3 = v2.softmax(dim=-1)
|
||||
v4 = torch.matmul(v3, x)
|
||||
return v4
|
||||
|
||||
x = torch.randn(1, 3, 64, 64, device=device)
|
||||
ref_result = ref(x)
|
||||
with sdpa_kernel(backends=[kernel]):
|
||||
sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
|
||||
self.assertEqual(ref_result, sdp_math)
|
||||
|
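The negative-scale test above relies on the math backend computing softmax(Q @ K^T * scale) @ V when an explicit scale is passed, so scale=-1.0 / 0.0001 reproduces ref's division by -0.0001. A hedged standalone check of that equivalence, using the public sdpa_kernel context manager the way the test does:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    x = torch.randn(1, 3, 16, 16)
    scale = -1.0 / 0.0001

    manual = torch.softmax((x @ x.transpose(-1, -2)) * scale, dim=-1) @ x
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        sdpa = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=scale)

    print(torch.allclose(manual, sdpa, atol=1e-5))  # expected: True
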
class TestSDPACudaOnly(NNTestCase):
    """ Used to test CUDA only functionality of scaled_dot_product_attention

@@ -2845,7 +2845,6 @@
# Transformer
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
  result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true))

- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
  output_differentiability: [True, False, False, False]

@@ -6718,22 +6718,6 @@ Tensor logsumexp_jvp(
  }
}

Tensor safe_logsumexp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    IntArrayRef dim,
    bool keepdim) {
  auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim);
  const auto neg_inf = at::scalar_tensor(
      -std::numeric_limits<float>::infinity(),
      at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
  const auto masked = self_p.eq(neg_inf);
  const auto masked_rows = all(masked, dim, true);
  const auto zero = at::scalar_tensor(
      0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
  return at::where(masked_rows, zero, lse_jvp);
}

Tensor warn_backwards(const Tensor& grad_output) {
  TORCH_WARN("Warn from backward");
  return grad_output;

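The derivatives.yaml entry further up wires the forward-mode derivative of _safe_softmax to this helper as result_t = result * (self_t - safe_logsumexp_jvp(...)), which is the standard softmax JVP: d softmax(x) = softmax(x) * (dx - sum(softmax(x) * dx)), where the inner sum is exactly the JVP of logsumexp; the safe variant additionally zeroes the tangent on rows that are entirely -inf. A hedged Python check of that identity on unmasked rows (softmax_jvp_sketch is an illustrative name):

    import torch

    def softmax_jvp_sketch(x, dx, dim=-1):
        s = torch.softmax(x, dim=dim)
        lse_jvp = (s * dx).sum(dim=dim, keepdim=True)   # JVP of logsumexp
        return s * (dx - lse_jvp)

    x = torch.randn(3, 5, dtype=torch.float64)
    dx = torch.randn_like(x)

    # Compare against torch.func.jvp applied to a plain softmax
    _, expected = torch.func.jvp(lambda t: torch.softmax(t, dim=-1), (x,), (dx,))
    print(torch.allclose(softmax_jvp_sketch(x, dx), expected))  # expected: True
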
@@ -229,11 +229,6 @@ at::Tensor logsumexp_jvp(
    const at::Tensor& self_t,
    IntArrayRef dim,
    bool keepdim);
at::Tensor safe_logsumexp_jvp(
    const at::Tensor& self_p,
    const at::Tensor& self_t,
    IntArrayRef dim,
    bool keepdim);
at::Tensor logcumsumexp_backward(
    at::Tensor grad,
    const at::Tensor& self,

@@ -16210,8 +16210,8 @@ op_db: List[OpInfo] = [
        sample_inputs_func=sample_inputs_safe_softmax,
        assert_jit_shape_analysis=True,
        assert_autodiffed=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        supports_out=False,
        supports_cow_input_no_materialize_backward=False,
        decorators=[],

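The OpInfo flags in this hunk toggle whether the test suite exercises forward-mode AD for _safe_softmax (supports_forward_ad) and forward-over-reverse (supports_fwgrad_bwgrad). As a rough illustration of what the forward-AD check amounts to, here is a hedged sketch using dual tensors on a plain softmax (not the OpInfo machinery itself):

    import torch
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(2, 4, dtype=torch.float64)
    tangent = torch.randn_like(x)

    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, tangent)
        dual_out = torch.softmax(dual_x, dim=-1)   # the op under test must accept dual tensors
        out, jvp = fwAD.unpack_dual(dual_out)

    print(out.shape, jvp.shape)  # both (2, 4); jvp is the forward-mode derivative
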