Revert "Update fused kernels and call _safe_softmax from SDPA (#131863)"

This reverts commit caba37e99b03d2199848197de4e452b78c8c2a23.

Reverted https://github.com/pytorch/pytorch/pull/131863 on behalf of https://github.com/izaitsevfb due to breaks executorch test executorch/backends/apple/coreml:test - test_vit_skip_conv (executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner) ([comment](https://github.com/pytorch/pytorch/pull/131863#issuecomment-2291855634))
PyTorch MergeBot
2024-08-15 17:55:07 +00:00
parent d3b458e603
commit cfec69e2a1
13 changed files with 19 additions and 194 deletions

View File

@@ -452,15 +452,9 @@ void cpu_flash_attention(
dst_data,
headSize);
}
// dst <- dst / sum[row]
// reorder MHA output with strides
for (int64_t row = 0; row < qBlockSize; ++row) {
// Row sums for fully masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set fully
// masked out rows to 0
qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
accum_t sum_reciprocal = 1 / qk_sum_data[row];
vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
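The block removed above is the guard for fully masked query rows in the CPU flash-attention kernel: for such rows the running max stays at -infinity and the row sum of exponentials stays at 0, so the final division would produce NaN. As a rough sketch of the idea in PyTorch-flavored Python (the names qk_max, qk_sum and dst_row mirror the kernel's variables, but this is not the kernel's code):

    import torch

    def normalize_row(dst_row: torch.Tensor, qk_max: float, qk_sum: float) -> torch.Tensor:
        # A fully masked row never accumulates anything: max is -inf, sum is 0.
        if qk_max == float("-inf"):
            qk_max = 0.0   # keep -inf out of any later logsumexp-style bookkeeping
        if qk_sum == 0.0:
            qk_sum = 1.0   # dividing by 1 leaves the all-zero row as zeros, not NaN
        return dst_row * (1.0 / qk_sum)

With the guard reverted, fully masked rows once again divide by a zero sum and can come out as NaN.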

View File

@@ -8890,7 +8890,6 @@
variants: method, function
dispatch:
QuantizedCPU: eq_quantized_cpu
NestedTensorCPU, NestedTensorCUDA: eq_tensor_nested
tags: [core, pointwise]
- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

View File

@@ -322,14 +322,5 @@ Tensor eq_scalar_nested(const Tensor& self, const Scalar& other) {
});
}
Tensor eq_tensor_nested(const Tensor& self, const Tensor& other) {
TORCH_CHECK(!other.is_nested(), "eq does not support nested tensor as other value.");
return NestedTensor_elementwise_Tensor(
self, other, "eq", false /*supports_striding*/,
[](const Tensor& b1, const Tensor& b2) {
return b1.eq(b2);
});
}
} // namespace native
} // namespace at
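The dispatch entry dropped in the yaml hunk above and the eq_tensor_nested kernel removed here appear to exist to support _safe_softmax, which compares its (possibly nested) input against a 0-dim -inf tensor via Tensor.eq. A rough illustration of the call pattern that support enabled (this snippet is illustrative, not taken from the PR):

    import torch

    nt = torch.nested.nested_tensor([torch.tensor([1.0, float("-inf")]),
                                     torch.tensor([2.0, 3.0, float("-inf")])])
    neg_inf = torch.tensor(float("-inf"))
    # With eq_tensor_nested registered, this produces a nested boolean mask;
    # after the revert the nested-tensor eq.Tensor overload is no longer registered.
    mask = nt.eq(neg_inf)

The TORCH_CHECK in the removed kernel explicitly rejects a nested `other`, so only nested-vs-dense comparisons were supported.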

View File

@@ -647,11 +647,9 @@ Tensor _safe_softmax(
int64_t dim,
std::optional<ScalarType> dtype) {
auto out = at::softmax(self, dim, dtype);
const auto neg_inf = at::scalar_tensor(-std::numeric_limits<float>::infinity(), at::TensorOptions().dtype(out.dtype()).device(out.device()));
const auto masked = self.eq(neg_inf);
const auto masked = self.eq(-std::numeric_limits<float>::infinity());
const auto masked_rows = all(masked, dim, true);
const auto zero = at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device()));
return at::where(masked_rows, zero, out);
return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
@@ -839,7 +837,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
attn.add_(*attn_mask);
}
}
attn = at::_safe_softmax(attn, -1);
attn = at::softmax(attn, -1);
if (dropout_p > 0.0) {
if (dropout_mask.has_value()) {
// In order to validate the correctness of the fused kernels, we need to
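The two hunks above carry the core of the revert in this file: _safe_softmax itself keeps its "zero out fully masked rows" semantics (only the construction of the -inf and zero scalars differs between the two versions shown), while _scaled_dot_product_attention_math switches back from at::_safe_softmax to plain at::softmax. A sketch of the _safe_softmax semantics in Python (not the C++ implementation):

    import torch

    def safe_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
        out = torch.softmax(x, dim=dim)
        # A row that is entirely -inf softmaxes to NaN; return zeros there instead.
        fully_masked = x.eq(float("-inf")).all(dim=dim, keepdim=True)
        return torch.where(fully_masked, torch.zeros_like(out), out)

After the revert, the math backend again applies an ordinary softmax, so fully masked attention rows come out as NaN rather than zeros.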

View File

@@ -144,10 +144,7 @@ class MemoryEfficientAttentionNormalize {
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
// Row sums for fully masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set them to 0.
ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
ElementCompute alpha = isLast ? (1 / denom) : 1;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
@@ -177,10 +174,7 @@ class MemoryEfficientAttentionNormalize {
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
// Row sums for fully masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set them to 0.
ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
ElementCompute alpha = isLast ? (1 / denom) : 1;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(
alpha, converted_accumulator); // X = alpha * C + uniform

View File

@@ -1166,10 +1166,6 @@ struct AttentionKernel {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (thread_id() < p.num_queries) {
// We set fully masked out rows to 0; the sumexp for masked out rows will be 0.
// We update it to be 1 prior to calling log so that log(1) = 0.
s_prime[thread_id()] = (s_prime[thread_id()] == 0) ? 1: s_prime[thread_id()];
mi[thread_id()] = (mi[thread_id()] == -cutlass::platform::numeric_limits<accum_t>::infinity()) ? 0: mi[thread_id()];
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
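For context, the guard removed here protects the fused attention kernel's logsumexp write-out. Reading the surrounding code, mi is the running row maximum kept in base-2 units and s_prime is the rescaled row sum of exponentials, so the stored value is

    LSE_i = mi / log2(e) + log(s_prime_i)

For a fully masked row mi = -inf and s_prime_i = 0, so the expression evaluates to -inf (or worse, depending on fast_log(0)); the removed lines remapped that case to 0 + log(1) = 0, so the saved logsumexp for fully masked rows was exactly 0. (The variable roles are inferred from the kernel, not stated in this diff.)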

View File

@@ -1791,6 +1791,9 @@ class TestOperators(TestCase):
), # NYI: forward-AD for soft_margin_loss_backward
xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
xfail(
"torch.ops.aten._safe_softmax.default"
), # NYI: forward-AD for _safe_softmax
skip("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail(
@@ -1973,6 +1976,9 @@ class TestOperators(TestCase):
xfail(
"nn.functional.ctc_loss"
), # ForwardAD not implemented and no decomposition
xfail(
"torch.ops.aten._safe_softmax.default"
), # ForwardAD not implemented
xfail("nn.functional.dropout2d"), # calls random op
xfail("nn.functional.dropout3d"), # calls random op
xfail("nn.functional.dropout"), # calls random op

View File

@@ -12385,21 +12385,13 @@ if __name__ == '__main__':
result = model(encoder_input, src_key_padding_mask=mask)
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
# Values of 1 are masked. Since there is only 1 input embedding this
# will result in NaN.
mask = torch.tensor([[1]], device=device) == 1
result = model(encoder_input, src_key_padding_mask=mask)
fast_path_device = result.is_cuda or result.is_cpu
result = result.cpu().detach().numpy()
# Non Fast Paths
if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
# We changed the semantics on the non fast path so that fully masked out rows return
# 0 from attention thus NaNs should no longer be present and the output should be nonzero
# due to skip connections
self.assertTrue(not np.isnan(result).any())
else:
# Fast Paths
self.assertTrue(np.isnan(result).all())
# deterministic input
encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
[[5., 6., 7., 8.]]], device=device, dtype=dtype))
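The assertions removed above covered the behavior #131863 introduced: with a src_key_padding_mask that masks the only token, the non-fast-path attention returned zeros (so the encoder output stayed finite via the residual connections), while the fast path still produced NaN. A minimal repro of the scenario, with illustrative hyperparameters rather than the test's exact configuration:

    import torch

    layer = torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, batch_first=True)
    x = torch.randn(1, 1, 4)               # batch of 1, a single token
    mask = torch.tensor([[True]])          # that single token is padding
    with torch.no_grad():
        out = layer(x, src_key_padding_mask=mask)
    print(torch.isnan(out).any())          # expected True with the reverted (pre-#131863) behavior

Since the revert restores NaN for fully masked rows on both paths, the branch-specific assertions were removed rather than updated.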

View File

@@ -347,7 +347,6 @@ class TestTransformers(NNTestCase):
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
# MHA converts all
with torch.no_grad():
B = 2
L = 4
@@ -357,7 +356,7 @@ class TestTransformers(NNTestCase):
if attn_mask_dim == 2:
attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
elif attn_mask_dim == 3:
attn_mask = make_tensor((B, 1, L, L), dtype=mask_dtype, device=device).expand(B, H, L, L).reshape(B * H, L, L)
attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
elif attn_mask_dim is None:
attn_mask = None
@@ -373,9 +372,7 @@ class TestTransformers(NNTestCase):
out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
mha.eval() # enable fast path
out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
# The FP kernel will return NaNs, while the sdpa kernel which is run when the fast path is turned off returns 0 instead
# of NaNs for fully masked rows
torch.testing.assert_close(out, out_fp.nan_to_num())
self.assertEqual(out, out_fp)
@parametrize("nhead", [1, 4, 8])
def test_transformerencoderlayer_src_mask(self, device, nhead):
@@ -1159,25 +1156,6 @@ class TestTransformers(NNTestCase):
else:
actual = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask, dropout_p, is_causal)
# This tests the fully masked out rows case
if torch.isnan(expected).any():
row_sums = attn_mask.sum(dim=-1)
masked_out_rows = (row_sums == 0)
for _ in range((input_dim - attn_mask_dim) - 1):
masked_out_rows = masked_out_rows.unsqueeze(0)
masked_out_rows = masked_out_rows.expand(expected.shape[:-1])
# Slice out the fully masked rows from expected and actual
expected_masked_out = expected[masked_out_rows]
actual_masked_out = actual[masked_out_rows]
expected_all_nan = torch.isnan(expected_masked_out).all()
actual_all_zero = (actual_masked_out.abs().sum() == 0)
self.assertTrue(expected_all_nan)
self.assertTrue(actual_all_zero)
return
self.assertEqual(actual, expected)
@@ -1983,7 +1961,7 @@ class TestSDPACpuOnly(NNTestCase):
@parametrize("n_head", [1, 3])
@parametrize("head_dim", [8])
@parametrize("mask_dim", [2, 4])
@parametrize("bool_mask", [False, True])
@parametrize("bool_mask", [0, 1])
@parametrize("train", [True, False])
@parametrize("casual", [True, False])
@parametrize("set_attn_mask", [True, False])
@@ -2058,9 +2036,6 @@ class TestSDPACpuOnly(NNTestCase):
if dtype in [torch.bfloat16, torch.float16]:
math_ref = math_ref.to(dtype)
self.assertFalse(torch.isnan(math_ref).any())
self.assertFalse(torch.isnan(actual).any())
self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
if train:
@@ -2089,104 +2064,6 @@ class TestSDPACpuOnly(NNTestCase):
actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
self.assertEqual(math_ref, actual)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION])
@parametrize("seq_len", [32, 64, 128])
@parametrize("head_dim", [16, 32])
@parametrize("dtype", [torch.float32, torch.float16])
def test_fully_masked_out_rows(self, backend, device, seq_len, head_dim, dtype):
def attention_inputs(seq_len, head_dim, device, dtype, mask_every_n_rows=4):
query = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
key = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
value = torch.rand(1, 1, seq_len, head_dim, requires_grad=True, device=device, dtype=dtype)
# Create a mask with deterministic row masking
mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=device)
# Mask every nth row
mask[0, 0, ::mask_every_n_rows, :] = False
# Create a fixed pattern for element-wise masking
element_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
element_mask[torch.arange(seq_len)[:, None] % 5 == torch.arange(seq_len) % 5] = True
# Combine row masking and element-wise masking
mask = mask & element_mask.unsqueeze(0).unsqueeze(0)
return query, key, value, mask
def compute_output_and_grads(query, key, value, mask, backend):
with sdpa_kernel(backend):
masked_out = scaled_dot_product_attention(query, key, value, attn_mask=mask)
loss = masked_out.sum()
grads = torch.autograd.grad(loss, [query, key, value])
return masked_out, grads
if backend == SDPBackend.FLASH_ATTENTION and "cuda" in str(device):
unittest.skip("FlashAttention does not support masks on cuda")
return
if backend == SDPBackend.EFFICIENT_ATTENTION and "cpu" in str(device):
unittest.skip("EfficientAttention does not support masks on cpu")
return
query, key, value, mask = attention_inputs(seq_len, head_dim, device, dtype)
# Compute results for the tested backend
backend_out, backend_grads = compute_output_and_grads(query, key, value, mask, backend)
# Compute results for the Math backend
math_out, math_grads = compute_output_and_grads(query, key, value, mask, SDPBackend.MATH)
# Compare outputs
torch.testing.assert_close(backend_out, math_out, atol=5e-3, rtol=0)
self.assertFalse(backend_out.isnan().any())
self.assertFalse(math_out.isnan().any())
# Compare gradients
for bg, mg in zip(backend_grads, math_grads):
torch.testing.assert_close(bg, mg, atol=3e-3, rtol=0)
self.assertFalse(bg.isnan().any())
self.assertFalse(mg.isnan().any())
# Check if masked rows are zero in output
mask_sum = mask.sum(dim=-1, keepdim=True)
masked_rows = (mask_sum == 0).expand_as(backend_out)
self.assertTrue((mask_sum == 0).sum() > 0, "No fully masked out rows found")
assert torch.all(backend_out[masked_rows] == 0), \
f"Non-zero values in fully masked rows for {backend=}"
# Check if gradients for masked rows are zero
grad_query = backend_grads[0]
assert torch.all(grad_query[masked_rows] == 0), f"Non-zero gradients in fully masked rows for {backend=}"
@parametrize("dtype", [torch.float32, torch.float16])
@parametrize("fill_val", [float("inf")])
def test_non_masked_rows_nan_props(self, device, dtype, fill_val):
query = torch.randn(1, 2, 4, 16, device=device, dtype=dtype)
# a single NaN in the query input
query[0, 1, 2, 3] = fill_val
query = query.detach().requires_grad_(True)
key = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
value = torch.randn(1, 2, 4, 16, device=device, dtype=dtype, requires_grad=True)
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
self.assertTrue(torch.isnan(out).any())
out.sum().backward()
self.assertTrue(torch.isnan(query.grad).any())
@parametrize("kernel", [SDPBackend.MATH])
def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
# https://github.com/pytorch/pytorch/issues/105190.
def ref(x):
v1 = torch.matmul(x, x.transpose(-1, -2))
v2 = v1 / -0.0001
v3 = v2.softmax(dim=-1)
v4 = torch.matmul(v3, x)
return v4
x = torch.randn(1, 3, 64, 64, device=device)
ref_result = ref(x)
with sdpa_kernel(backends=[kernel]):
sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
self.assertEqual(ref_result, sdp_math)
class TestSDPACudaOnly(NNTestCase):
""" Used to test CUDA only functionality of scaled_dot_product_attention

View File

@@ -2845,7 +2845,6 @@
# Transformer
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true))
- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
output_differentiability: [True, False, False, False]
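The `result:` line removed above was the forward-mode (JVP) formula that the revert deletes along with the safe_logsumexp_jvp helper removed in the following hunks. It is the standard softmax tangent: for y = softmax(x) along dim with input tangent x_t,

    y_t = y * (x_t - sum_dim(y * x_t))

and the inner term sum_dim(y * x_t) is exactly the JVP of logsumexp(x, dim), which is what safe_logsumexp_jvp computed, with rows whose inputs are all -inf forced to 0 so that the tangent of the zeroed-out rows stays 0 instead of NaN.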

View File

@@ -6718,22 +6718,6 @@ Tensor logsumexp_jvp(
}
}
Tensor safe_logsumexp_jvp(
const Tensor& self_p,
const Tensor& self_t,
IntArrayRef dim,
bool keepdim) {
auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim);
const auto neg_inf = at::scalar_tensor(
-std::numeric_limits<float>::infinity(),
at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
const auto masked = self_p.eq(neg_inf);
const auto masked_rows = all(masked, dim, true);
const auto zero = at::scalar_tensor(
0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
return at::where(masked_rows, zero, lse_jvp);
}
Tensor warn_backwards(const Tensor& grad_output) {
TORCH_WARN("Warn from backward");
return grad_output;

View File

@@ -229,11 +229,6 @@ at::Tensor logsumexp_jvp(
const at::Tensor& self_t,
IntArrayRef dim,
bool keepdim);
at::Tensor safe_logsumexp_jvp(
const at::Tensor& self_p,
const at::Tensor& self_t,
IntArrayRef dim,
bool keepdim);
at::Tensor logcumsumexp_backward(
at::Tensor grad,
const at::Tensor& self,

View File

@@ -16210,8 +16210,8 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_safe_softmax,
assert_jit_shape_analysis=True,
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
supports_out=False,
supports_cow_input_no_materialize_backward=False,
decorators=[],