mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
3e38feb05f
commit
596b418391
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Defines bias subclasses that work with scaled_dot_product_attention"""
|
||||
|
||||
from enum import auto, IntEnum
|
||||
from typing import Optional
|
||||
from warnings import warn
|
||||
@ -101,9 +102,15 @@ class CausalBias(torch.Tensor):
|
||||
# Create a lower-right causal bias
|
||||
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
|
||||
|
||||
q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
|
||||
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
|
||||
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
|
||||
q = torch.randn(
|
||||
bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
|
||||
)
|
||||
k = torch.randn(
|
||||
bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
|
||||
)
|
||||
v = torch.randn(
|
||||
bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
|
||||
)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user