[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
commit 596b418391 (parent 3e38feb05f)
Author: Xuehai Pan
Date:   2025-06-14 00:48:12 +08:00
Committed by: PyTorch MergeBot
65 changed files with 640 additions and 475 deletions

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Defines bias subclasses that work with scaled_dot_product_attention"""
+
 from enum import auto, IntEnum
 from typing import Optional
 from warnings import warn
@@ -101,9 +102,15 @@ class CausalBias(torch.Tensor):
 # Create a lower-right causal bias
 attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
-q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
-k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
-v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
+q = torch.randn(
+    bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
+)
+k = torch.randn(
+    bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
+)
+v = torch.randn(
+    bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
+)
 out = F.scaled_dot_product_attention(q, k, v, attn_bias)
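
For reference, a minimal, self-contained sketch of the reformatted docstring example above; the batch, head, and sequence sizes and the CPU/float32 fallback are illustrative assumptions, not values taken from the diff.

    # Sketch of the CausalBias docstring example; sizes are illustrative assumptions.
    import torch
    import torch.nn.functional as F
    from torch.nn.attention.bias import causal_lower_right

    bsz, num_heads, head_dim = 2, 4, 64
    seqlen_q, seqlen_kv = 8, 16  # shorter query -> lower-right aligned causal mask

    # Fall back to CPU/float32 when CUDA is unavailable so the sketch still runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Create a lower-right causal bias and pass it as the attention mask to SDPA.
    attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
    q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device=device, dtype=dtype)
    k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device=device, dtype=dtype)
    v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device=device, dtype=dtype)
    out = F.scaled_dot_product_attention(q, k, v, attn_bias)
    print(out.shape)  # torch.Size([2, 4, 8, 64])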